diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 98445de2a..58c81c9e7 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -466,3 +466,34 @@ async def get_subprotocol(url): accepted_subprotocol = loop.run_until_complete(get_subprotocol(url)) assert accepted_subprotocol == subprotocol loop.close() + + +@pytest.mark.parametrize("protocol_cls", WS_PROTOCOLS) +def test_server_lost_connection(protocol_cls): + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + await self.receive() + + # Simulate a lost connection + # without receiving a close frame + self.send.__self__.connection_lost(None) + + with pytest.raises(Exception) as exc: + await self.send({"type": "websocket.send", "text": "123"}) + + assert exc.value.code == 1006 + + async def websocket_session(url): + async with websockets.connect(url) as websocket: + await websocket.ping() + await websocket.send("abc") + # Delay exiting context manager + # to avoid sending a close frame before server attempts send + await asyncio.sleep(1) + + if protocol_cls is WebSocketProtocol: + with run_server(App, protocol_cls=protocol_cls) as url: + loop = asyncio.new_event_loop() + loop.run_until_complete(websocket_session(url)) + loop.close()