Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reuse the digest auth state to avoid unnecessary requests #2463

Merged
merged 9 commits into from Nov 29, 2022
21 changes: 16 additions & 5 deletions httpx/_auth.py
Expand Up @@ -155,8 +155,15 @@ def __init__(
) -> None:
self._username = to_bytes(username)
self._password = to_bytes(password)
self._last_challenge: typing.Optional[_DigestAuthChallenge] = None
self._nonce_count = 1

def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
if self._last_challenge:
request.headers["Authorization"] = self._build_auth_header(
request, self._last_challenge
)

response = yield request

if response.status_code != 401 or "www-authenticate" not in response.headers:
Expand All @@ -172,8 +179,12 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non
# header, then we don't need to build an authenticated request.
return

challenge = self._parse_challenge(request, response, auth_header)
request.headers["Authorization"] = self._build_auth_header(request, challenge)
self._last_challenge = self._parse_challenge(request, response, auth_header)
self._nonce_count = 1

request.headers["Authorization"] = self._build_auth_header(
request, self._last_challenge
)
yield request

def _parse_challenge(
Expand Down Expand Up @@ -222,9 +233,9 @@ def digest(data: bytes) -> bytes:
# TODO: implement auth-int
HA2 = digest(A2)

nonce_count = 1 # TODO: implement nonce counting
nc_value = b"%08x" % nonce_count
cnonce = self._get_client_nonce(nonce_count, challenge.nonce)
nc_value = b"%08x" % self._nonce_count
cnonce = self._get_client_nonce(self._nonce_count, challenge.nonce)
self._nonce_count += 1

HA1 = digest(A1)
if challenge.algorithm.lower().endswith("-sess"):
Expand Down
109 changes: 83 additions & 26 deletions tests/client/test_auth.py
Expand Up @@ -8,6 +8,7 @@
import os
import threading
import typing
from urllib.request import parse_keqv_list

import pytest

Expand Down Expand Up @@ -151,14 +152,14 @@ async def async_auth_flow(
@pytest.mark.asyncio
async def test_basic_auth() -> None:
url = "https://example.org/"
auth = ("tomchristie", "password123")
auth = ("user", "password123")
app = App()

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
response = await client.get(url, auth=auth)

assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}


@pytest.mark.asyncio
Expand All @@ -167,7 +168,7 @@ async def test_basic_auth_with_stream() -> None:
See: https://github.com/encode/httpx/pull/1312
"""
url = "https://example.org/"
auth = ("tomchristie", "password123")
auth = ("user", "password123")
app = App()

async with httpx.AsyncClient(
Expand All @@ -177,25 +178,25 @@ async def test_basic_auth_with_stream() -> None:
await response.aread()

assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}


@pytest.mark.asyncio
async def test_basic_auth_in_url() -> None:
url = "https://tomchristie:password123@example.org/"
url = "https://user:password123@example.org/"
app = App()

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
response = await client.get(url)

assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}


@pytest.mark.asyncio
async def test_basic_auth_on_session() -> None:
url = "https://example.org/"
auth = ("tomchristie", "password123")
auth = ("user", "password123")
app = App()

async with httpx.AsyncClient(
Expand All @@ -204,7 +205,7 @@ async def test_basic_auth_on_session() -> None:
response = await client.get(url)

assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}


@pytest.mark.asyncio
Expand Down Expand Up @@ -279,7 +280,7 @@ async def test_trust_env_auth() -> None:
@pytest.mark.asyncio
async def test_auth_disable_per_request() -> None:
url = "https://example.org/"
auth = ("tomchristie", "password123")
auth = ("user", "password123")
app = App()

async with httpx.AsyncClient(
Expand Down Expand Up @@ -317,13 +318,13 @@ async def test_auth_property() -> None:
async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
assert client.auth is None

client.auth = ("tomchristie", "password123") # type: ignore
client.auth = ("user", "password123") # type: ignore
assert isinstance(client.auth, BasicAuth)

url = "https://example.org/"
response = await client.get(url)
assert response.status_code == 200
assert response.json() == {"auth": "Basic dG9tY2hyaXN0aWU6cGFzc3dvcmQxMjM="}
assert response.json() == {"auth": "Basic dXNlcjpwYXNzd29yZDEyMw=="}


@pytest.mark.asyncio
Expand All @@ -347,7 +348,7 @@ async def test_auth_invalid_type() -> None:
@pytest.mark.asyncio
async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
app = App()

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
Expand All @@ -360,7 +361,7 @@ async def test_digest_auth_returns_no_auth_if_no_digest_header_in_response() ->

def test_digest_auth_returns_no_auth_if_alternate_auth_scheme() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
auth_header = "Token ..."
app = App(auth_header=auth_header, status_code=401)

Expand All @@ -375,7 +376,7 @@ def test_digest_auth_returns_no_auth_if_alternate_auth_scheme() -> None:
@pytest.mark.asyncio
async def test_digest_auth_200_response_including_digest_auth_header() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
auth_header = 'Digest realm="realm@host.com",qop="auth",nonce="abc",opaque="xyz"'
app = App(auth_header=auth_header, status_code=200)

Expand All @@ -390,7 +391,7 @@ async def test_digest_auth_200_response_including_digest_auth_header() -> None:
@pytest.mark.asyncio
async def test_digest_auth_401_response_without_digest_auth_header() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
app = App(auth_header="", status_code=401)

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
Expand Down Expand Up @@ -419,7 +420,7 @@ async def test_digest_auth(
algorithm: str, expected_hash_length: int, expected_response_length: int
) -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
app = DigestApp(algorithm=algorithm)

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
Expand All @@ -435,7 +436,7 @@ async def test_digest_auth(
response_fields = [field.strip() for field in fields.split(",")]
digest_data = dict(field.split("=") for field in response_fields)

assert digest_data["username"] == '"tomchristie"'
assert digest_data["username"] == '"user"'
assert digest_data["realm"] == '"httpx@example.org"'
assert "nonce" in digest_data
assert digest_data["uri"] == '"/"'
Expand All @@ -450,7 +451,7 @@ async def test_digest_auth(
@pytest.mark.asyncio
async def test_digest_auth_no_specified_qop() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
app = DigestApp(qop="")

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
Expand All @@ -469,7 +470,7 @@ async def test_digest_auth_no_specified_qop() -> None:
assert "qop" not in digest_data
assert "nc" not in digest_data
assert "cnonce" not in digest_data
assert digest_data["username"] == '"tomchristie"'
assert digest_data["username"] == '"user"'
assert digest_data["realm"] == '"httpx@example.org"'
assert len(digest_data["nonce"]) == 64 + 2 # extra quotes
assert digest_data["uri"] == '"/"'
Expand All @@ -482,7 +483,7 @@ async def test_digest_auth_no_specified_qop() -> None:
@pytest.mark.asyncio
async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str) -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
app = DigestApp(qop=qop)

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
Expand All @@ -495,7 +496,7 @@ async def test_digest_auth_qop_including_spaces_and_auth_returns_auth(qop: str)
@pytest.mark.asyncio
async def test_digest_auth_qop_auth_int_not_implemented() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
app = DigestApp(qop="auth-int")

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
Expand All @@ -506,7 +507,7 @@ async def test_digest_auth_qop_auth_int_not_implemented() -> None:
@pytest.mark.asyncio
async def test_digest_auth_qop_must_be_auth_or_auth_int() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
app = DigestApp(qop="not-auth")

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
Expand All @@ -517,7 +518,7 @@ async def test_digest_auth_qop_must_be_auth_or_auth_int() -> None:
@pytest.mark.asyncio
async def test_digest_auth_incorrect_credentials() -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
app = DigestApp(send_response_after_attempt=2)

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
Expand All @@ -527,6 +528,62 @@ async def test_digest_auth_incorrect_credentials() -> None:
assert len(response.history) == 1


@pytest.mark.asyncio
async def test_digest_auth_reuses_challenge() -> None:
url = "https://example.org/"
auth = DigestAuth(username="user", password="password123")
app = DigestApp()

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
response_1 = await client.get(url, auth=auth)
response_2 = await client.get(url, auth=auth)

assert response_1.status_code == 200
assert response_2.status_code == 200

assert len(response_1.history) == 1
assert len(response_2.history) == 0
rettier marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.asyncio
async def test_digest_auth_resets_nonce_count_after_401() -> None:
url = "https://example.org/"
auth = DigestAuth(username="user", password="password123")
app = DigestApp()

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
response_1 = await client.get(url, auth=auth)
assert response_1.status_code == 200
assert len(response_1.history) == 1

first_nonce = parse_keqv_list(
response_1.request.headers["Authorization"].split(", ")
)["nonce"]
first_nc = parse_keqv_list(
response_1.request.headers["Authorization"].split(", ")
)["nc"]

# with this we now force a 401 on a subsequent (but initial) request
app.send_response_after_attempt = 2
rettier marked this conversation as resolved.
Show resolved Hide resolved

# we expect the client again to try to authenticate, i.e. the history length must be 1
response_2 = await client.get(url, auth=auth)
assert response_2.status_code == 200
assert len(response_2.history) == 1

second_nonce = parse_keqv_list(
response_2.request.headers["Authorization"].split(", ")
)["nonce"]
second_nc = parse_keqv_list(
response_2.request.headers["Authorization"].split(", ")
)["nc"]

assert first_nonce != second_nonce # ensures that the auth challenge was reset
assert (
first_nc == second_nc
) # ensures the nonce count is reset when the authentication failed


@pytest.mark.parametrize(
"auth_header",
[
Expand All @@ -539,7 +596,7 @@ async def test_async_digest_auth_raises_protocol_error_on_malformed_header(
auth_header: str,
) -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
app = App(auth_header=auth_header, status_code=401)

async with httpx.AsyncClient(transport=httpx.MockTransport(app)) as client:
Expand All @@ -558,7 +615,7 @@ def test_sync_digest_auth_raises_protocol_error_on_malformed_header(
auth_header: str,
) -> None:
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
app = App(auth_header=auth_header, status_code=401)

with httpx.Client(transport=httpx.MockTransport(app)) as client:
Expand Down Expand Up @@ -629,7 +686,7 @@ async def handle_async_request(self, request: Request) -> Response:
@pytest.mark.asyncio
async def test_digest_auth_unavailable_streaming_body():
url = "https://example.org/"
auth = DigestAuth(username="tomchristie", password="password123")
auth = DigestAuth(username="user", password="password123")
app = DigestApp()

async def streaming_body():
Expand Down
40 changes: 40 additions & 0 deletions tests/test_auth.py
Expand Up @@ -3,6 +3,8 @@

Integration tests also exist in tests/client/test_auth.py
"""
from urllib.request import parse_keqv_list

import pytest

import httpx
Expand Down Expand Up @@ -61,3 +63,41 @@ def test_digest_auth_with_401():
response = httpx.Response(content=b"Hello, world!", status_code=200)
with pytest.raises(StopIteration):
flow.send(response)


def test_digest_auth_with_401_nonce_counting():
rettier marked this conversation as resolved.
Show resolved Hide resolved
auth = httpx.DigestAuth(username="user", password="pass")
request = httpx.Request("GET", "https://www.example.com")

# The initial request should not include an auth header.
flow = auth.sync_auth_flow(request)
request = next(flow)
assert "Authorization" not in request.headers

# If a 401 response is returned, then a digest auth request is made.
headers = {
"WWW-Authenticate": 'Digest realm="...", qop="auth", nonce="...", opaque="..."'
}
response = httpx.Response(
content=b"Auth required", status_code=401, headers=headers
)
first_request = flow.send(response)
assert first_request.headers["Authorization"].startswith("Digest")

# Each subsequent request contains the digest header by default...
request = httpx.Request("GET", "https://www.example.com")
flow = auth.sync_auth_flow(request)
second_request = next(flow)
assert second_request.headers["Authorization"].startswith("Digest")

# ... and the client nonce count (nc) is increased
first_nc = parse_keqv_list(first_request.headers["Authorization"].split(", "))["nc"]
second_nc = parse_keqv_list(second_request.headers["Authorization"].split(", "))[
"nc"
]
assert int(first_nc, 16) + 1 == int(second_nc, 16)

# No other requests are made.
response = httpx.Response(content=b"Hello, world!", status_code=200)
with pytest.raises(StopIteration):
flow.send(response)