Skip to content

Commit

Permalink
Invoke server func after init values received
Browse files Browse the repository at this point in the history
  • Loading branch information
jcheng5 committed Jun 24, 2022
1 parent 70885ae commit 8f4ec1f
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 31 deletions.
20 changes: 13 additions & 7 deletions shiny/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,31 @@ def __init__(self):
# make those more configurable if we need to customize the HTTPConnection (like
# "scheme", "path", and "query_string").
self._http_conn = HTTPConnection(scope={"type": "websocket", "headers": {}})
self._queue: asyncio.Queue[str] = asyncio.Queue()

async def send(self, message: str) -> None:
pass

# I should say I’m not 100% that the receive method can be a no-op for our testing
# purposes. It might need to be asyncio.sleep(0), and/or it might need an external
# way to yield until we tell the connection to continue, so that the run loop can
# continue.
async def receive(self) -> str:
# Sleep forever
await asyncio.Event().wait()
raise RuntimeError("make the type checker happy")
msg = await self._queue.get()
if msg == "":
raise ConnectionClosed()
return msg

async def close(self, code: int, reason: Optional[str]) -> None:
pass

def get_http_conn(self) -> HTTPConnection:
return self._http_conn

def cause_receive(self, message: str) -> None:
"""Call from tests to simulate the other side sending a message"""
self._queue.put_nowait(message)

def cause_disconnect(self) -> None:
"""Call from tests to simulate the other side disconnecting"""
self.cause_receive("")


class StarletteConnection(Connection):
def __init__(self, conn: starlette.websockets.WebSocket):
Expand Down
57 changes: 39 additions & 18 deletions shiny/session/_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__all__ = ("Session", "Inputs", "Outputs")

import enum
import functools
import os
from pathlib import Path
Expand Down Expand Up @@ -58,6 +59,19 @@
from .. import render
from .. import _utils


class ConnectionState(enum.Enum):
Start = 0
Running = 1
Closed = 2


class ProtocolError(Exception):
def __init__(self, message: str = ""):
super(ProtocolError, self).__init__(message)
self.message = message


# This cast is necessary because if the type checker thinks that if
# "tag" isn't in `message`, then it's not a ClientMessage object.
# This will be fixable when TypedDict items can be marked as
Expand Down Expand Up @@ -186,9 +200,6 @@ def __init__(
self._flush_callbacks = _utils.Callbacks()
self._flushed_callbacks = _utils.Callbacks()

with session_context(self):
self.app.server(self.input, self.output, self)

def _register_session_end_callbacks(self) -> None:
# This is to be called from the initialization. It registers functions
# that are called when a session ends.
Expand All @@ -213,6 +224,12 @@ async def close(self, code: int = 1001) -> None:
self._run_session_end_tasks()

async def _run(self) -> None:
conn_state: ConnectionState = ConnectionState.Start

def verify_state(expected_state: ConnectionState) -> None:
if conn_state != expected_state:
raise ProtocolError("Invalid method for the current session state")

await self._send_message(
{"config": {"workerId": "", "sessionId": str(self.id), "user": None}}
)
Expand All @@ -228,8 +245,8 @@ async def _run(self) -> None:
message, object_hook=_utils.lists_to_tuples
)
except json.JSONDecodeError:
print("ERROR: Invalid JSON message")
continue
warnings.warn("ERROR: Invalid JSON message")
return

if "method" not in message_obj:
self._send_error_response("Message does not contain 'method'.")
Expand All @@ -238,36 +255,40 @@ async def _run(self) -> None:
async with lock():

if message_obj["method"] == "init":
verify_state(ConnectionState.Start)

conn_state = ConnectionState.Running
message_obj = typing.cast(ClientMessageInit, message_obj)
self._manage_inputs(message_obj["data"])

with session_context(self):
self.app.server(self.input, self.output, self)

elif message_obj["method"] == "update":
verify_state(ConnectionState.Running)

message_obj = typing.cast(ClientMessageUpdate, message_obj)
self._manage_inputs(message_obj["data"])

else:
if "tag" not in message_obj:
warnings.warn(
"Cannot dispatch message with missing 'tag'; method: "
+ message_obj["method"]
)
return
if "args" not in message_obj:
warnings.warn(
"Cannot dispatch message with missing 'args'; method: "
+ message_obj["method"]
)
return
elif "tag" in message_obj and "args" in message_obj:
verify_state(ConnectionState.Running)

message_obj = typing.cast(ClientMessageOther, message_obj)
await self._dispatch(message_obj)

else:
raise ProtocolError(
f"Unrecognized method {message_obj['method']}"
)

self._request_flush()

await flush()

except ConnectionClosed:
self._run_session_end_tasks()
except ProtocolError as pe:
self._send_error_response(pe.message)

def _manage_inputs(self, data: Dict[str, object]) -> None:
for (key, val) in data.items():
Expand Down
21 changes: 15 additions & 6 deletions tests/test_modules.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Tests for `Module`."""

import asyncio
from typing import Dict, Union, cast

import pytest

from htmltools import Tag, TagList
from shiny import *
from shiny.session import get_current_session
from shiny._connection import MockConnection
from shiny._namespaces import resolve_id
from shiny._utils import run_coro_sync
from htmltools import TagList, Tag
from shiny.session import get_current_session


@module.ui
Expand Down Expand Up @@ -37,7 +39,8 @@ def test_module_ui():
assert get_id(y, 2) == "outer-out2"


def test_session_scoping():
@pytest.mark.asyncio
async def test_session_scoping():

sessions: Dict[str, Union[Session, None, str]] = {}

Expand Down Expand Up @@ -85,8 +88,14 @@ def _():
sessions["top_id"] = session.ns("foo")
sessions["top_ui_id"] = get_id(mod_outer_ui("outer"), 0)

App(ui.TagList(), server)._create_session(MockConnection())
run_coro_sync(reactive.flush())
conn = MockConnection()
sess = App(ui.TagList(), server)._create_session(conn)

async def mock_client():
conn.cause_receive('{"method":"init","data":{}}')
conn.cause_disconnect()

await asyncio.gather(mock_client(), sess._run())

assert sessions["inner"] is sessions["inner_current"]
assert sessions["inner_current"] is sessions["inner_calc_current"]
Expand Down

0 comments on commit 8f4ec1f

Please sign in to comment.