Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default headers to WebSockets implementations #1606

Merged
merged 69 commits into from Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from 62 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
d8a87a5
added logic to accept default headers in websocket
iudeen Jul 15, 2022
170f16a
added default_headers to class init
iudeen Jul 15, 2022
bd28f39
fix
Jul 15, 2022
83a29e8
docs: added limitation about Websockets
Jul 15, 2022
eabe380
docs: added limitation about Websockets
Jul 15, 2022
5ecfd07
Update docs/settings.md
iudeen Jul 22, 2022
d954ef1
docs: moved websocket limitation
Jul 22, 2022
c7546bd
Update docs/settings.md
Kludex Jul 22, 2022
eef254d
Merge remote-tracking branch 'origin/master' into patch/websocket-ser…
Jul 25, 2022
f7fc772
feat(websockets): added server header support for websockets
Jul 25, 2022
87408cf
Merge branch 'encode:master' into master
iudeen Jul 25, 2022
62c6c7e
feat(websockets-wsproto): added server header support for websockets
Jul 25, 2022
4928484
added logic to accept default headers in websocket (#1)
iudeen Jul 25, 2022
d4c55c7
style(websockets): formatting
Jul 25, 2022
8e44b69
Merge branch 'patch/websocket-server-name'
Jul 25, 2022
4c860f7
style(websockets): formatting
Jul 25, 2022
728b4b9
style(websockets): formatting
Jul 25, 2022
2beab89
Merge branch 'encode:master' into master
iudeen Aug 12, 2022
7a8aa99
revert(websockets): removed work-around logic to add headers
Aug 12, 2022
94ddfe5
Revert "feat(websockets-wsproto): added server header support for web…
Aug 12, 2022
49b135b
feat(wsproto): Add headers support for wsproto implementation
Aug 12, 2022
f4ada86
docs: added limitation about Websockets - no-date-header
Aug 12, 2022
7a08ff0
docs: added limitation about Websockets - no-date-header
Aug 12, 2022
137fd38
Update uvicorn/protocols/websockets/wsproto_impl.py
iudeen Aug 15, 2022
e241caa
revert: removed default header logic
Aug 15, 2022
eede8f4
test: added test for wsproto implementation
Aug 15, 2022
62cbf96
style: fixed linting
Aug 15, 2022
877a9fc
test: added test for multiple server headers
Aug 16, 2022
1cbf30c
Feature/new (#2)
iudeen Aug 20, 2022
b061ede
flake-8 corrections
iudeen Aug 20, 2022
02c96e4
run black
iudeen Aug 20, 2022
3c6c915
Added support for no-server-header in Websocket Protocols
iudeen Aug 21, 2022
9b79b31
ran black
iudeen Aug 21, 2022
1548e92
update requirements.txt
iudeen Aug 21, 2022
2b7267f
Revert "update requirements.txt"
iudeen Aug 21, 2022
402be7b
update requirements to use wip main branch from Websockets
iudeen Aug 21, 2022
85c342b
update requirements to use wip main branch from Websockets
iudeen Aug 21, 2022
60f62c0
Merge remote-tracking branch 'upstream/master' into feature/websocket…
iudeen Aug 24, 2022
f290fe7
Merge remote-tracking branch 'upstream/master' into feature/websocket…
iudeen Aug 24, 2022
dc8a505
update requirements to use wip main branch from Websockets
iudeen Aug 24, 2022
019e7d2
wip: add allow-direct-references to support installation from Git
iudeen Aug 24, 2022
59cf6f4
update requirements to use wip main branch from Websockets
iudeen Aug 24, 2022
866f3e9
update requirements to use wip main branch from Websockets
iudeen Aug 24, 2022
f6a0872
update requirements to use wip main branch from Websockets
iudeen Aug 24, 2022
e31cd53
Merge branch 'encode:master' into feature/websocket-headers
iudeen Aug 28, 2022
9aa02c3
Bump wsproto from 1.1.0 to 1.2.0
iudeen Sep 4, 2022
93af5ac
docs: amend settings.md
iudeen Sep 10, 2022
a86da4b
Merge remote-tracking branch 'upstream/master' into feature/websocket…
iudeen Sep 10, 2022
7daff74
Merge remote-tracking branch 'upstream/master' into feature/websocket…
iudeen Sep 11, 2022
97ca634
style: ran isort
iudeen Sep 11, 2022
666a8b7
fix: removed `_added_names` logic to support arbitrary extra headers …
iudeen Sep 14, 2022
437c066
fix: removed `_added_names` logic to support arbitrary extra headers …
iudeen Sep 14, 2022
27bd66f
fix: removed `_added_names` logic to support arbitrary extra headers …
iudeen Sep 14, 2022
50d2e9f
fix: removed `_added_names` logic to support arbitrary extra headers …
iudeen Sep 14, 2022
37dcbd5
Merge branch 'encode:master' into feature/websocket-headers
iudeen Sep 27, 2022
adddfc3
Merge remote-tracking branch 'upstream/master' into feature/websocket…
iudeen Oct 26, 2022
52d871b
chore: remove git link to websockets library
iudeen Oct 26, 2022
0be0e7c
chore: bump websockets requirement to >=10.4
iudeen Oct 26, 2022
5673fa3
refactor: USER_AGENT is deprecated. Make changes to accommodate
iudeen Oct 26, 2022
b51c8c2
style: run black
iudeen Oct 26, 2022
c05759a
tests: added test (improve coverage)
iudeen Oct 26, 2022
b19f58a
tests: added test (improve coverage)
iudeen Oct 26, 2022
f03c77a
Merge branch 'encode:master' into feature/websocket-headers
iudeen Oct 27, 2022
09bcd73
tests: added test (improve coverage) for WSProto
iudeen Oct 27, 2022
321c8e6
Refactor feature/websocket-headers
Kludex Oct 28, 2022
da1f87f
remove unused imports
iudeen Oct 28, 2022
8b545ad
add clarity in docs
iudeen Oct 28, 2022
afae8ed
remove repeated test
iudeen Oct 28, 2022
1768b31
remove white space
iudeen Oct 28, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/settings.md
Expand Up @@ -90,7 +90,7 @@ connecting IPs in the `forwarded-allow-ips` configuration.
* `--date-header` / `--no-date-header` - Enable/Disable default `Date` header.

!!! note
The `--no-server-header` flag doesn't have effect on the WebSockets implementations.
The `--no-date-header` flag doesn't have effect on the WebSockets implementations.
iudeen marked this conversation as resolved.
Show resolved Hide resolved

## HTTPS

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -42,7 +42,7 @@ standard = [
"PyYAML>=5.1",
"uvloop>=0.14.0,!=0.15.0,!=0.15.1; sys_platform != 'win32' and (sys_platform != 'cygwin' and platform_python_implementation != 'PyPy')",
"watchfiles>=0.13",
"websockets>=10.0",
"websockets>=10.4",
]

[project.scripts]
Expand Down
198 changes: 197 additions & 1 deletion tests/protocols/test_websocket.py
Expand Up @@ -5,11 +5,13 @@

from tests.protocols.test_http import HTTP_PROTOCOLS
from tests.utils import run_server
from uvicorn import Server
from uvicorn.config import Config
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol

try:
import websockets
import websockets.exceptions
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory

from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
Expand All @@ -18,9 +20,9 @@
WebSocketProtocol = None
ClientPerMessageDeflateFactory = None


ONLY_WEBSOCKETPROTOCOL = [p for p in [WebSocketProtocol] if p is not None]
WS_PROTOCOLS = [p for p in [WSProtocol, WebSocketProtocol] if p is not None]
ONLY_WS_PROTOCOL = [p for p in [WSProtocol] if p is not None]
pytestmark = pytest.mark.skipif(
websockets is None, reason="This test needs the websockets module"
)
Expand Down Expand Up @@ -658,3 +660,197 @@ async def send_text(url):
await send_text("ws://127.0.0.1:8000")

assert frames == [b"abc", b"abc", b"abc"]


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_default_server_headers(ws_protocol_cls, http_protocol_cls):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.response_headers

config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
headers = await open_connection("ws://127.0.0.1:8000")
assert headers.get("server") == "uvicorn" and "date" in headers


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_no_server_headers(ws_protocol_cls, http_protocol_cls):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.response_headers

config = Config(
app=App,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
server_header=False,
)
async with run_server(config):
headers = await open_connection("ws://127.0.0.1:8000")
assert "server" not in headers


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", ONLY_WS_PROTOCOL)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_no_date_header(ws_protocol_cls, http_protocol_cls):
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not 100% true. I forgot to mention... The Date headers from the websockets cannot be removed... 🤔

Should we do something about it on the tests? It's on purpose... 🤔

It's a question, I'm still thinking.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do what on tests? There is no way to remove date header in websockets I think.

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, I know, but the tests make us believe that we forgot about Date headers... 🤔

class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.response_headers

config = Config(
app=App,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
date_header=False,
)
async with run_server(config):
headers = await open_connection("ws://127.0.0.1:8000")
assert "date" not in headers


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", ONLY_WS_PROTOCOL)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_multiple_server_header_in_ws(ws_protocol_cls, http_protocol_cls):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send(
{
"type": "websocket.accept",
"headers": [
(b"Server", b"over-ridden"),
(b"Server", b"another-value"),
],
}
)

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.response_headers

config = Config(
app=App,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
)
async with run_server(config):
headers = await open_connection("ws://127.0.0.1:8000")
assert all(
x in headers.get_all("Server") for x in ["over-ridden", "another-value"]
)


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_multiple_arbitrary_headers_with_same_name(
ws_protocol_cls, http_protocol_cls
):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send(
{
"type": "websocket.accept",
"headers": [
(b"Potato", b"cool-potato"),
(b"Potato", b"super-cool-potato"),
],
}
)

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.response_headers

config = Config(
app=App,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
)
async with run_server(config):
headers = await open_connection("ws://127.0.0.1:8000")
assert all(
x in headers.get_all("Potato") for x in ["cool-potato", "super-cool-potato"]
)


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", ONLY_WEBSOCKETPROTOCOL)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_multiple_server_header_in_websockets(ws_protocol_cls, http_protocol_cls):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send(
{
"type": "websocket.accept",
"headers": [
(b"Server", b"over-ridden"),
(b"Server", b"another-value"),
],
iudeen marked this conversation as resolved.
Show resolved Hide resolved
}
)

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.response_headers

config = Config(
app=App,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
)
async with run_server(config):
headers = await open_connection("ws://127.0.0.1:8000")
assert len(headers.get_all("Server")) == 1
assert headers.get("Server") == "uvicorn"


@pytest.mark.anyio
iudeen marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("ws_protocol_cls", ONLY_WEBSOCKETPROTOCOL)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_shutdown_when_connection_active_in_websockets(
ws_protocol_cls, http_protocol_cls
):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

config = Config(
app=App,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
)
server = Server(config=config)
cancel_handle = asyncio.ensure_future(server.serve(sockets=None))
iudeen marked this conversation as resolved.
Show resolved Hide resolved
await asyncio.sleep(0.1)
async with websockets.connect("ws://127.0.0.1:8000"):
ws_conn = list(server.server_state.connections)[0]
ws_conn.shutdown()
assert ws_conn.ws_server.closing is True
assert ws_conn.transport.is_closing()
iudeen marked this conversation as resolved.
Show resolved Hide resolved
await server.shutdown()
cancel_handle.cancel()
9 changes: 8 additions & 1 deletion uvicorn/protocols/utils.py
@@ -1,6 +1,6 @@
import asyncio
import urllib.parse
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple

if TYPE_CHECKING:
from asgiref.typing import WWWScope
Expand Down Expand Up @@ -53,3 +53,10 @@ def get_path_with_query_string(scope: "WWWScope") -> str:
path_with_query_string, scope["query_string"].decode("ascii")
)
return path_with_query_string


def get_server_header(default_headers: List[Tuple[bytes, bytes]], override: str) -> str:
return next(
(i for i in default_headers if i[0] in (b"Server", b"server")),
[override.encode()],
)[-1].decode()
28 changes: 21 additions & 7 deletions uvicorn/protocols/websockets/websockets_impl.py
Expand Up @@ -19,12 +19,13 @@
get_local_addr,
get_path_with_query_string,
get_remote_addr,
get_server_header,
is_ssl,
)
from uvicorn.server import ServerState

if sys.version_info < (3, 8):
from typing_extensions import Literal
from typing_extensions import Literal # pragma: no cover
iudeen marked this conversation as resolved.
Show resolved Hide resolved
else:
from typing import Literal

Expand Down Expand Up @@ -70,10 +71,17 @@ def __init__(
self.app = config.loaded_app
self.loop = _loop or asyncio.get_event_loop()
self.root_path = config.root_path
if self.config.server_header:
self.server_header = get_server_header(
default_headers=server_state.default_headers, override="uvicorn"
)
else:
self.server_header = None

# Shared server state
self.connections = server_state.connections
self.tasks = server_state.tasks
self.default_headers = server_state.default_headers

# Connection state
self.transport: asyncio.Transport = None # type: ignore[assignment]
Expand Down Expand Up @@ -103,6 +111,7 @@ def __init__(
max_size=self.config.ws_max_size,
ping_interval=self.config.ws_ping_interval,
ping_timeout=self.config.ws_ping_timeout,
server_header=self.server_header,
extensions=extensions,
logger=logging.getLogger("uvicorn.error"),
extra_headers=[],
Expand Down Expand Up @@ -265,12 +274,17 @@ async def asgi_send(self, message: "ASGISendEvent") -> None:
self.accepted_subprotocol = cast(
Optional[Subprotocol], message.get("subprotocol")
)
if "headers" in message:
self.extra_headers.extend(
# ASGI spec requires bytes
# But for compatibility we need to convert it to strings
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in message["headers"]
headers = list(message.get("headers", []))
for name, value in headers:
if name.lower() in [b"server"]:
continue
iudeen marked this conversation as resolved.
Show resolved Hide resolved
self.extra_headers.append(
(
# ASGI spec requires bytes
# But for compatibility we need to convert it to strings
name.decode("latin-1"),
value.decode("latin-1"),
)
)
self.handshake_started_event.set()

Expand Down
3 changes: 2 additions & 1 deletion uvicorn/protocols/websockets/wsproto_impl.py
Expand Up @@ -32,6 +32,7 @@ def __init__(self, config, server_state, _loop=None):
# Shared server state
self.connections = server_state.connections
self.tasks = server_state.tasks
self.default_headers = server_state.default_headers

# Connection state
self.transport = None
Expand Down Expand Up @@ -253,7 +254,7 @@ async def send(self, message):
)
self.handshake_complete = True
subprotocol = message.get("subprotocol")
extra_headers = message.get("headers", [])
extra_headers = list(message.get("headers", [])) + self.default_headers
extensions = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
Expand Down