From 8e1417c3af7239ec42deb2d786ee082076a3d137 Mon Sep 17 00:00:00 2001 From: Phillip Berndt Date: Thu, 19 Jan 2023 09:20:50 +0100 Subject: [PATCH] Compute host header correctly Signatures need to include the host header, but the requests library does not include it in prepared requests by default. Rather, it trusts that Python's HTTP client will compute and inject it when sending the request. This forces requests-aws4auth to compute what this header will look like. A slight discrepancy between the implementations is that the code in this library unconditionally skips the port, whereas the request ending up being sent will include a port if it does not match the URL scheme's default. This change adjusts the implementations to match in that regard. Fixes #34 --- requests_aws4auth/aws4auth.py | 10 +++++++++- requests_aws4auth/test/test_requests_aws4auth.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/requests_aws4auth/aws4auth.py b/requests_aws4auth/aws4auth.py index e0b55ff..2370378 100644 --- a/requests_aws4auth/aws4auth.py +++ b/requests_aws4auth/aws4auth.py @@ -615,7 +615,15 @@ def get_canonical_headers(cls, req, include=None): # in the signed headers, but Requests doesn't include it in a # PreparedRequest if 'host' not in headers: - headers['host'] = urlparse(str(req.url)).netloc.split(':')[0] + purl = urlparse(str(req.url)) + netloc = purl.netloc + # Python's http client only includes the port if it is non-default, + # see http.client.HTTPConnection.putrequest. The request URL, on the + # other hand, might explicitly include it. + if ((purl.port == 80 and purl.scheme == 'http') + or (purl.port == 443 and purl.scheme == 'https')): + netloc = netloc.rsplit(":", 1)[0] + headers['host'] = netloc # Aggregate for upper/lowercase header name collisions in header names, # AMZ requires values of colliding headers be concatenated into a # single header with lowercase name. 
Although this is not possible with diff --git a/requests_aws4auth/test/test_requests_aws4auth.py b/requests_aws4auth/test/test_requests_aws4auth.py index dd25dc2..7be3134 100644 --- a/requests_aws4auth/test/test_requests_aws4auth.py +++ b/requests_aws4auth/test/test_requests_aws4auth.py @@ -949,7 +949,7 @@ def test_netloc_port(self): request. """ - req = requests.Request('GET', 'http://amazonaws.com:8443') + req = requests.Request('GET', 'https://amazonaws.com:443') preq = req.prepare() self.assertNotIn('host', preq.headers) result = AWS4Auth.get_canonical_headers(preq, include=['host'])