Skip to content

Commit

Permalink
Move away from ModuleSession/ModuleInput/ModuleOutput toward SessionP…
Browse files Browse the repository at this point in the history
…roxy
  • Loading branch information
cpsievert committed May 20, 2022
1 parent 8acca98 commit c921b8f
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 257 deletions.
150 changes: 13 additions & 137 deletions shiny/_modules.py
@@ -1,90 +1,24 @@
__all__ = ("namespaced_id", "module_ui", "module_server")

from typing import Any, Callable, Optional, TypeVar, ParamSpec, Concatenate
import sys
from typing import Callable, TypeVar

from htmltools import TagChildArg
if sys.version_info < (3, 10):
from typing_extensions import ParamSpec, Concatenate
else:
from typing import ParamSpec, Concatenate

from ._docstring import add_example
from ._namespaces import namespaced_id
from ._namespaces import namespaced_id, namespace_context, get_current_namespaces
from .session import Inputs, Outputs, Session, require_active_session, session_context


class ModuleInputs(Inputs):
"""
A class representing a module's outputs.
Warning
-------
An instance of this class is created for each request and passed as an argument to
the :class:`shiny.modules.Module`'s ``server`` function. For this reason, you
shouldn't need to create instances of this class yourself. Furthermore, you
probably shouldn't need this class for type checking either since it has the same
signature as :class:`shiny.session.Session`.
"""

def __init__(self, ns: str, parent_inputs: Inputs):
self._ns = namespaced_id(ns, parent_inputs._ns) # Support nested modules
# Don't set _parent attribute like the other classes since Inputs redefines
# __setattr__
self._map = parent_inputs._map


class ModuleOutputs(Outputs):
"""
A class representing a module's outputs.
Warning
-------
An instance of this class is created for each request and passed as an argument to
the :class:`shiny.modules.Module`'s ``server`` function. For this reason, you
shouldn't need to create instances of this class yourself. Furthermore, you
probably shouldn't need this class for type checking either since it has the same
signature as :class:`shiny.session.Session`.
"""

def __init__(self, ns: str, parent_outputs: Outputs):
self._ns = namespaced_id(ns, parent_outputs._ns) # Support nested modules
self._parent = parent_outputs

def __getattr__(self, attr: str) -> Any:
return getattr(self._parent, attr)


class ModuleSession(Session):
"""
A class representing a module's outputs.
Warning
-------
An instance of this class is created for each request and passed as an argument to
the :class:`shiny.modules.Module`'s ``server`` function. For this reason, you
shouldn't need to create instances of this class yourself. Furthermore, you
probably shouldn't need this class for type checking either since it has the same
signature as :class:`shiny.session.Session`.
"""

def __init__(self, ns: str, parent_session: Session):
self._ns: str = namespaced_id(ns, parent_session._ns) # Support nested modules
self._parent: Session = parent_session
self.input: ModuleInputs = ModuleInputs(ns, parent_session.input)
self.output: ModuleOutputs = ModuleOutputs(ns, parent_session.output)

def __getattr__(self, attr: str) -> Any:
return getattr(self._parent, attr)


class MockModuleSession(ModuleSession):
def __init__(self, ns: str):
self._ns = ns


P = ParamSpec("P")
R = TypeVar("R")


def module_ui(fn: Callable[P, R]) -> Callable[Concatenate[str, P], R]:
def wrapper(ns: str, *args: P.args, **kwargs: P.kwargs) -> R:
with session_context(MockModuleSession(ns)):
# TODO: what should happen if this is called *inside* of a session? Do we tack on the parent session's namespace as well?
with namespace_context(get_current_namespaces() + [ns]):
return fn(*args, **kwargs)

return wrapper
Expand All @@ -94,67 +28,9 @@ def module_server(
fn: Callable[Concatenate[Inputs, Outputs, Session, P], R]
) -> Callable[Concatenate[str, P], R]:
def wrapper(ns: str, *args: P.args, **kwargs: P.kwargs) -> R:
mod_sess = ModuleSession(ns, require_active_session(None))
with session_context(mod_sess):
return fn(mod_sess.input, mod_sess.output, mod_sess, *args, **kwargs)
sess = require_active_session(None)
child_sess = sess.make_scope(ns)
with session_context(child_sess):
return fn(child_sess.input, child_sess.output, child_sess, *args, **kwargs)

return wrapper


# @add_example()
# class Module:
# """
# Modularize UI and server-side logic.
#
# Parameters
# ----------
# ui
# The module's UI definition.
# server
# The module's server-side logic.
# """
#
# def __init__(
# self,
# ui: Callable[..., TagChildArg],
# server: Callable[[ModuleInputs, ModuleOutputs, ModuleSession], None],
# ) -> None:
# self._ui: Callable[..., TagChildArg] = ui
# self._server: Callable[
# [ModuleInputs, ModuleOutputs, ModuleSession], None
# ] = server
#
# def ui(self, ns: str, *args: Any) -> TagChildArg:
# """
# Render the module's UI.
#
# Parameters
# ----------
# namespace
# A namespace for the module.
# args
# Additional arguments to pass to the module's UI definition.
# """
#
# # Create a fake session so that namespaced_id() knows
# # what the relevant namespace is
# with session_context(MockModuleSession(ns)):
# return self._ui(*args)
#
# def server(self, ns: str, *, session: Optional[Session] = None) -> None:
# """
# Invoke the module's server-side logic.
#
# Parameters
# ----------
# ns
# A namespace for the module.
# session
# A :class:`~shiny.Session` instance. If not provided, it is inferred via
# :func:`~shiny.session.get_current_session`.
# """
#
# mod_sess = ModuleSession(ns, require_active_session(session))
# with session_context(mod_sess):
# return self._server(mod_sess.input, mod_sess.output, mod_sess)
#
51 changes: 28 additions & 23 deletions shiny/_namespaces.py
@@ -1,34 +1,39 @@
# TODO: make this available under the shiny.modules API
__all__ = ("namespaced_id",)
from contextlib import contextmanager
from contextvars import ContextVar, Token
from typing import Union, List

from typing import Union, Optional

from .types import MISSING, MISSING_TYPE
class ResolvedId(str):
pass


def namespaced_id(id: str, ns: Union[str, MISSING_TYPE, None] = MISSING) -> str:
"""
Namespace an ID based on the current ``Module()``'s namespace.
Id = Union[str, ResolvedId]

Parameters
----------
id
The ID to namespace..
"""
if isinstance(ns, MISSING_TYPE):
ns = get_current_namespace()

if ns is None:
def namespaced_id(id: str) -> str:
return namespaced_id_ns(id, get_current_namespaces())


def namespaced_id_ns(id: Id, namespaces: List[str] = []) -> str:
if isinstance(id, ResolvedId) or len(namespaces) == 0:
return id
else:
return ns + "_" + id
return ResolvedId("_".join(namespaces) + "_" + id)


def get_current_namespace() -> Optional[str]:
from .session import get_current_session
def get_current_namespaces() -> List[str]:
return _current_namespaces.get()

session = get_current_session()
if session is None:
return None
else:
return session._ns

_current_namespaces: ContextVar[List[str]] = ContextVar(
"current_namespaces", default=[]
)


@contextmanager
def namespace_context(namespaces: List[str]):
token: Token[List[str]] = _current_namespaces.set(namespaces)
try:
yield
finally:
_current_namespaces.reset(token)
91 changes: 70 additions & 21 deletions shiny/session/_session.py
Expand Up @@ -50,7 +50,7 @@
from .._fileupload import FileInfo, FileUploadManager
from ..http_staticfiles import FileResponse
from ..input_handler import input_handlers
from .._namespaces import namespaced_id
from .._namespaces import namespaced_id_ns
from ..reactive import Value, Effect, Effect_, isolate, flush
from ..reactive._core import lock
from ..types import SafeException, SilentCancelOutputException, SilentException
Expand Down Expand Up @@ -113,7 +113,14 @@ def empty_outbound_message_queues() -> OutBoundMessageQueues:
return {"values": [], "input_messages": [], "errors": []}


class Session:
# Makes isinstance(x, Session) also return True when x is a SessionProxy (i.e., a module
# session)
class SessionMeta(type):
def __instancecheck__(self, __instance: Any) -> bool:
return isinstance(__instance, SessionProxy)


class Session(object, metaclass=SessionMeta):
"""
A class representing a user session.
Expand Down Expand Up @@ -143,8 +150,6 @@ def __init__(
self.input: Inputs = Inputs()
self.output: Outputs = Outputs(self)

self._ns: Optional[str] = None # Only relevant for ModuleSession

self.user: Union[str, None] = None
self.groups: Union[List[str], None] = None
credentials_json: str = ""
Expand Down Expand Up @@ -479,7 +484,7 @@ def send_input_message(self, id: str, message: Dict[str, object]) -> None:
message
The message to send.
"""
msg: Dict[str, object] = {"id": namespaced_id(id, self._ns), "message": message}
msg: Dict[str, object] = {"id": id, "message": message}
self._outbound_message_queues["input_messages"].append(msg)
self._request_flush()

Expand Down Expand Up @@ -720,6 +725,45 @@ def _process_ui(self, ui: TagChildArg) -> RenderedDeps:

return {"deps": deps, "html": res["html"]}

@staticmethod
def ns(id: str) -> str:
return id

def make_scope(self, id: str) -> "Session":
ns = create_ns_func(id)
return SessionProxy(parent=self, ns=ns) # type: ignore


class SessionProxy:
def __init__(self, parent: Session, ns: Callable[[str], str]) -> None:
self._parent = parent
self.ns = ns
self.input = Inputs(values=parent.input._map, ns=ns)
self.output = Outputs(
session=cast(Session, self),
effects=self.output._effects,
suspend_when_hidden=self.output._suspend_when_hidden,
ns=ns,
)

def __getattr__(self, attr: str) -> Any:
return getattr(self._parent, attr)

def send_input_message(self, id: str, message: Dict[str, object]) -> None:
return self._parent.send_input_message(self.ns(id), message)

def download(
self, name: str, **kwargs: object
) -> Callable[[DownloadHandler], None]:
return self._parent.download(self.ns(name), **kwargs)

def make_scope(self, id: str) -> Session:
return self._parent.make_scope(self.ns(id))


def create_ns_func(namespace: str) -> Callable[[str], str]:
return lambda x: namespaced_id_ns(x, [namespace])


# ======================================================================================
# Inputs
Expand All @@ -739,31 +783,30 @@ class Inputs:
for type checking reasons).
"""

def __init__(self, **kwargs: object) -> None:
self._map: dict[str, Value[Any]] = {}
for key, value in kwargs.items():
self._map[key] = Value(value, read_only=True)

self._ns: Optional[str] = None # Only relevant for ModuleInputs()
def __init__(
self, values: Dict[str, Value[Any]] = {}, ns: Callable[[str], str] = lambda x: x
) -> None:
self._map = values
self._ns = ns

def __setitem__(self, key: str, value: Value[Any]) -> None:
if not isinstance(value, Value):
raise TypeError("`value` must be a reactive.Value object.")

self._map[namespaced_id(key, self._ns)] = value
self._map[self._ns(key)] = value

def __getitem__(self, key: str) -> Value[Any]:
key = namespaced_id(key, self._ns)
key = self._ns(key)
# Auto-populate key if accessed but not yet set. Needed to take reactive
# dependencies on input values that haven't been received from client
# yet.
if key not in self._map:
self._map[key] = Value(read_only=True)
self._map[key] = cast(Value[Any], Value(read_only=True))

return self._map[key]

def __delitem__(self, key: str) -> None:
del self._map[namespaced_id(key, self._ns)]
del self._map[self._ns(key)]

# Allow access of values as attributes.
def __setattr__(self, attr: str, value: Value[Any]) -> None:
Expand Down Expand Up @@ -797,11 +840,17 @@ class Outputs:
for type checking reasons).
"""

def __init__(self, session: Session) -> None:
self._effects: Dict[str, Effect_] = {}
self._suspend_when_hidden: Dict[str, bool] = {}
self._session: Session = session
self._ns: Optional[str] = None # Only relevant for ModuleOutputs()
def __init__(
self,
session: Session,
ns: Callable[[str], str] = lambda x: x,
effects: Dict[str, Effect_] = {},
suspend_when_hidden: Dict[str, bool] = {},
) -> None:
self._session = session
self._ns = ns
self._effects = effects
self._suspend_when_hidden = suspend_when_hidden

def __call__(
self,
Expand All @@ -812,7 +861,7 @@ def __call__(
) -> Callable[[render.RenderFunction], None]:
def set_fn(fn: render.RenderFunction) -> None:
# Get the (possibly namespaced) output id
fn_name = namespaced_id(name or fn.__name__, self._ns)
fn_name = self._ns(name or fn.__name__)

# fn is either a regular function or a RenderFunction object. If
# it's the latter, we can give it a bit of metadata, which can be
Expand Down

0 comments on commit c921b8f

Please sign in to comment.