Skip to content

Commit

Permalink
Compute host header correctly
Browse files Browse the repository at this point in the history
Signatures need to include the host header, but the requests library
does not include it in prepared requests by default. Rather, it trusts
that Python's HTTP client will compute and inject it when sending the
request. This forces requests-aws4auth to compute how this header will
look like.

A slight discrepancy between the implementations is that the code in
this library unconditionally skips the port, whereas the request ending
up being sent will include a port if it does not match the URL scheme's
default.

This change adjusts the implementations to match in that regard.

Fixes #34
  • Loading branch information
phillipberndt committed Jan 20, 2023
1 parent 3b4d2da commit 8390d23
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
10 changes: 9 additions & 1 deletion requests_aws4auth/aws4auth.py
Expand Up @@ -615,7 +615,15 @@ def get_canonical_headers(cls, req, include=None):
# in the signed headers, but Requests doesn't include it in a
# PreparedRequest
if 'host' not in headers:
headers['host'] = urlparse(str(req.url)).netloc.split(':')[0]
purl = urlparse(str(req.url))
netloc = purl.netloc
# Python's http client only includes the port if it is non-default,
# see http.client.HTTPConnection.putrequest. The request URL, on the
# other hand, might explicitly include it.
if ((purl.port == 80 and purl.scheme == 'http')
or (purl.port == 443 and purl.scheme == 'https')):
netloc = netloc.rsplit(":", 1)[0]
headers['host'] = netloc
# Aggregate for upper/lowercase header name collisions in header names,
# AMZ requires values of colliding headers be concatenated into a
# single header with lowercase name. Although this is not possible with
Expand Down
2 changes: 1 addition & 1 deletion requests_aws4auth/test/test_requests_aws4auth.py
Expand Up @@ -949,7 +949,7 @@ def test_netloc_port(self):
request.
"""
req = requests.Request('GET', 'http://amazonaws.com:8443')
req = requests.Request('GET', 'https://amazonaws.com:443')
preq = req.prepare()
self.assertNotIn('host', preq.headers)
result = AWS4Auth.get_canonical_headers(preq, include=['host'])
Expand Down

0 comments on commit 8390d23

Please sign in to comment.