# Patch summary (commentary only; everything above the first "diff --git"
# header is outside the hunks).
#
# httpx/_urls.py -- URL.copy_with():
#   The old return statement built the result as
#   URL(self._uri_reference.copy_with(**kwargs).unsplit()).
#   The new code copies `self` into a fresh URL, assigns the updated
#   `_uri_reference` onto that copy directly, calls `.normalize()` on the
#   reference only when `is_absolute_url` is true, and returns URL(new_url).
#
# tests/models/test_url.py -- test_url_copywith_security():
#   New test; its docstring cites CVE-2021-41945. It asserts that scheme,
#   userinfo, netloc, raw_path, query and fragment keep their original values
#   after copy_with(), apart from the one component explicitly passed, in
#   three cases: no arguments, userinfo=b"", and a `scheme` argument that
#   itself looks like a complete URL.
#
# NOTE(review): this patch was recovered from a copy in which line breaks and
# runs of whitespace had been collapsed. Line structure below follows the hunk
# headers' declared counts; Python indentation was reconstructed from syntax
# (copy_with / copy_set_param take `self`, so they are presumed to be methods
# indented one level). The inner spacing of the three ASCII-art context
# comments in the first hunk appears to have been lost and is NOT restored --
# verify the context lines against the target file (or regenerate the patch)
# before applying.
diff --git a/httpx/_urls.py b/httpx/_urls.py
index 70486bc9e4..f6788e5568 100644
--- a/httpx/_urls.py
+++ b/httpx/_urls.py
@@ -484,7 +484,11 @@ def copy_with(self, **kwargs: typing.Any) -> "URL":
         # \_/ \______________/\_________/ \_________/ \__/
         # | | | | |
         # scheme authority path query fragment
-        return URL(self._uri_reference.copy_with(**kwargs).unsplit())
+        new_url = URL(self)
+        new_url._uri_reference = self._uri_reference.copy_with(**kwargs)
+        if new_url.is_absolute_url:
+            new_url._uri_reference = new_url._uri_reference.normalize()
+        return URL(new_url)
 
     def copy_set_param(self, key: str, value: typing.Any = None) -> "URL":
         return self.copy_with(params=self.params.set(key, value))
diff --git a/tests/models/test_url.py b/tests/models/test_url.py
index cd099bd931..a088fc2a10 100644
--- a/tests/models/test_url.py
+++ b/tests/models/test_url.py
@@ -308,6 +308,55 @@ def test_url_copywith_raw_path():
     assert url.raw_path == b"/some/path?a=123"
 
 
+def test_url_copywith_security():
+    """
+    Prevent unexpected changes on URL after calling copy_with (CVE-2021-41945)
+    """
+    url = httpx.URL("https://u:p@[invalid!]//evilHost/path?t=w#tw")
+    original_scheme = url.scheme
+    original_userinfo = url.userinfo
+    original_netloc = url.netloc
+    original_raw_path = url.raw_path
+    original_query = url.query
+    original_fragment = url.fragment
+    url = url.copy_with()
+    assert url.scheme == original_scheme
+    assert url.userinfo == original_userinfo
+    assert url.netloc == original_netloc
+    assert url.raw_path == original_raw_path
+    assert url.query == original_query
+    assert url.fragment == original_fragment
+
+    url = httpx.URL("https://u:p@[invalid!]//evilHost/path?t=w#tw")
+    original_scheme = url.scheme
+    original_netloc = url.netloc
+    original_raw_path = url.raw_path
+    original_query = url.query
+    original_fragment = url.fragment
+    url = url.copy_with(userinfo=b"")
+    assert url.scheme == original_scheme
+    assert url.userinfo == b""
+    assert url.netloc == original_netloc
+    assert url.raw_path == original_raw_path
+    assert url.query == original_query
+    assert url.fragment == original_fragment
+
+    url = httpx.URL("https://example.com/path?t=w#tw")
+    original_userinfo = url.userinfo
+    original_netloc = url.netloc
+    original_raw_path = url.raw_path
+    original_query = url.query
+    original_fragment = url.fragment
+    bad = "https://xxxx:xxxx@xxxxxxx/xxxxx/xxx?x=x#xxxxx"
+    url = url.copy_with(scheme=bad)
+    assert url.scheme == bad
+    assert url.userinfo == original_userinfo
+    assert url.netloc == original_netloc
+    assert url.raw_path == original_raw_path
+    assert url.query == original_query
+    assert url.fragment == original_fragment
+
+
 def test_url_invalid():
     with pytest.raises(httpx.InvalidURL):
         httpx.URL("https://😇/")