Skip to content

Commit

Permalink
Close #148: Execute server function after receiving initial message
Browse files Browse the repository at this point in the history
  • Loading branch information
cpsievert committed Jun 13, 2022
1 parent 841c21d commit b9f8339
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
25 changes: 23 additions & 2 deletions shiny/_connection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Optional
from typing import Callable, Optional, TYPE_CHECKING, cast

if TYPE_CHECKING:
from shiny import Session, Inputs, Outputs

import starlette.websockets
from starlette.websockets import WebSocketState
from starlette.requests import HTTPConnection

from . import _utils


class Connection(ABC):
"""Abstract class to serve a session and send/receive messages to the
Expand All @@ -29,12 +34,28 @@ def get_http_conn(self) -> HTTPConnection:


class MockConnection(Connection):
def __init__(self):
def __init__(
self, server: Optional[Callable[["Inputs", "Outputs", "Session"], None]] = None
) -> None:
# This currently hard-codes some basic values for scope. In the future, we could
# make those more configurable if we need to customize the HTTPConnection (like
# "scheme", "path", and "query_string").
self._http_conn = HTTPConnection(scope={"type": "websocket", "headers": {}})

self._on_ended_callbacks = _utils.Callbacks()

if server is not None:
from .session import Inputs, Outputs, Session, session_context

self = cast(Session, self)
self.input = Inputs()
self.output = Outputs(self)
with session_context(self):
server(self.input, self.output, self)

def on_ended(self, fn: Callable[[], None]) -> Callable[[], None]:
return self._on_ended_callbacks.register(fn)

async def send(self, message: str) -> None:
pass

Expand Down
5 changes: 2 additions & 3 deletions shiny/session/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,6 @@ def __init__(
self._flush_callbacks = _utils.Callbacks()
self._flushed_callbacks = _utils.Callbacks()

with session_context(self):
self.app.server(self.input, self.output, self)

def _register_session_end_callbacks(self) -> None:
# This is to be called from the initialization. It registers functions
# that are called when a session ends.
Expand Down Expand Up @@ -229,6 +226,8 @@ async def _run(self) -> None:
if message_obj["method"] == "init":
message_obj = typing.cast(ClientMessageInit, message_obj)
self._manage_inputs(message_obj["data"])
with session_context(self):
self.app.server(self.input, self.output, self)

elif message_obj["method"] == "update":
message_obj = typing.cast(ClientMessageUpdate, message_obj)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _():
sessions["top_current"] = get_current_session()
sessions["top_calc_current"] = out()

App(ui.TagList(), server)._create_session(MockConnection())
MockConnection(server)
run_coro_sync(reactive.flush())

assert sessions["inner"] is sessions["inner_current"]
Expand All @@ -141,5 +141,5 @@ def _():

assert sessions["top"] is sessions["top_current"]
assert sessions["top_current"] is sessions["top_calc_current"]
assert isinstance(sessions["top_current"], Session)
assert isinstance(sessions["top_current"], MockConnection)
assert not isinstance(sessions["top_current"], ModuleSession)

0 comments on commit b9f8339

Please sign in to comment.