diff --git a/starlette/testclient.py b/starlette/testclient.py index 8de1e3e31..48afcb673 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -368,6 +368,7 @@ def __init__( backend: str = "asyncio", backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, cookies: httpx._client.CookieTypes = None, + headers: typing.Dict[str, str] = None, ) -> None: self.async_backend = _AsyncBackend( backend=backend, backend_options=backend_options or {} @@ -385,10 +386,13 @@ def __init__( raise_server_exceptions=raise_server_exceptions, root_path=root_path, ) + if headers is None: + headers = {} + headers.setdefault("user-agent", "testclient") super().__init__( app=self.app, base_url=base_url, - headers={"user-agent": "testclient"}, + headers=headers, transport=transport, follow_redirects=True, cookies=cookies, diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 84fcfbafd..a8715ea2d 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -11,6 +11,7 @@ from starlette.middleware import Middleware from starlette.responses import JSONResponse, Response from starlette.routing import Route +from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect @@ -67,6 +68,25 @@ def homepage(request): assert response.json() == {"mock": "example"} +def test_testclient_headers_behavior(): + """ + We should be able to use the test client with user defined headers. + + This is useful if we need to set custom headers for authentication + during tests or in development. + """ + + client = TestClient(mock_service) + assert client.headers.get("user-agent") == "testclient" + + client = TestClient(mock_service, headers={"user-agent": "non-default-agent"}) + assert client.headers.get("user-agent") == "non-default-agent" + + client = TestClient(mock_service, headers={"Authentication": "Bearer 123"}) + assert client.headers.get("user-agent") == "testclient" + assert client.headers.get("Authentication") == "Bearer 123" + + def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name): """ This test asserts a number of properties that are important for an