Skip to content

Commit

Permalink
Add default headers to WebSockets implementations (#1606)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
Co-authored-by: Irfanuddin <irfanuddin@knowledgelens.com>
Co-authored-by: Irfanuddin <irfanuddinshafi@gmail.com>
  • Loading branch information
4 people committed Oct 29, 2022
1 parent 3c21f63 commit ff6d50f
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/settings.md
Expand Up @@ -90,7 +90,7 @@ connecting IPs in the `forwarded-allow-ips` configuration.
* `--date-header` / `--no-date-header` - Enable/Disable default `Date` header.

!!! note
The `--no-server-header` flag doesn't have effect on the WebSockets implementations.
The `--no-date-header` flag doesn't have effect on the `websockets` implementation.

## HTTPS

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -42,7 +42,7 @@ standard = [
"PyYAML>=5.1",
"uvloop>=0.14.0,!=0.15.0,!=0.15.1; sys_platform != 'win32' and (sys_platform != 'cygwin' and platform_python_implementation != 'PyPy')",
"watchfiles>=0.13",
"websockets>=10.0",
"websockets>=10.4",
]

[project.scripts]
Expand Down
100 changes: 99 additions & 1 deletion tests/protocols/test_websocket.py
Expand Up @@ -10,6 +10,7 @@

try:
import websockets
import websockets.exceptions
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory

from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
Expand All @@ -18,8 +19,8 @@
WebSocketProtocol = None
ClientPerMessageDeflateFactory = None


ONLY_WEBSOCKETPROTOCOL = [p for p in [WebSocketProtocol] if p is not None]
ONLY_WS_PROTOCOL = [p for p in [WSProtocol] if p is not None]
WS_PROTOCOLS = [p for p in [WSProtocol, WebSocketProtocol] if p is not None]
pytestmark = pytest.mark.skipif(
websockets is None, reason="This test needs the websockets module"
Expand Down Expand Up @@ -658,3 +659,100 @@ async def send_text(url):
await send_text("ws://127.0.0.1:8000")

assert frames == [b"abc", b"abc", b"abc"]


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_default_server_headers(ws_protocol_cls, http_protocol_cls):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.response_headers

config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off")
async with run_server(config):
headers = await open_connection("ws://127.0.0.1:8000")
assert headers.get("server") == "uvicorn" and "date" in headers


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_no_server_headers(ws_protocol_cls, http_protocol_cls):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.response_headers

config = Config(
app=App,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
server_header=False,
)
async with run_server(config):
headers = await open_connection("ws://127.0.0.1:8000")
assert "server" not in headers


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", ONLY_WS_PROTOCOL)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_no_date_header(ws_protocol_cls, http_protocol_cls):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.response_headers

config = Config(
app=App,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
date_header=False,
)
async with run_server(config):
headers = await open_connection("ws://127.0.0.1:8000")
assert "date" not in headers


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_multiple_server_header(ws_protocol_cls, http_protocol_cls):
class App(WebSocketResponse):
async def websocket_connect(self, message):
await self.send(
{
"type": "websocket.accept",
"headers": [
(b"Server", b"over-ridden"),
(b"Server", b"another-value"),
],
}
)

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.response_headers

config = Config(
app=App,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
)
async with run_server(config):
headers = await open_connection("ws://127.0.0.1:8000")
assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"]
10 changes: 7 additions & 3 deletions uvicorn/protocols/websockets/websockets_impl.py
Expand Up @@ -23,9 +23,9 @@
)
from uvicorn.server import ServerState

if sys.version_info < (3, 8):
if sys.version_info < (3, 8): # pragma: py-gte-38
from typing_extensions import Literal
else:
else: # pragma: py-lt-38
from typing import Literal

if TYPE_CHECKING:
Expand Down Expand Up @@ -103,9 +103,13 @@ def __init__(
max_size=self.config.ws_max_size,
ping_interval=self.config.ws_ping_interval,
ping_timeout=self.config.ws_ping_timeout,
server_header=None,
extensions=extensions,
logger=logging.getLogger("uvicorn.error"),
extra_headers=[],
extra_headers=[
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in server_state.default_headers
],
)

def connection_made( # type: ignore[override]
Expand Down
3 changes: 2 additions & 1 deletion uvicorn/protocols/websockets/wsproto_impl.py
Expand Up @@ -32,6 +32,7 @@ def __init__(self, config, server_state, _loop=None):
# Shared server state
self.connections = server_state.connections
self.tasks = server_state.tasks
self.default_headers = server_state.default_headers

# Connection state
self.transport = None
Expand Down Expand Up @@ -255,7 +256,7 @@ async def send(self, message):
)
self.handshake_complete = True
subprotocol = message.get("subprotocol")
extra_headers = message.get("headers", [])
extra_headers = self.default_headers + list(message.get("headers", []))
extensions = []
if self.config.ws_per_message_deflate:
extensions.append(PerMessageDeflate())
Expand Down

0 comments on commit ff6d50f

Please sign in to comment.