Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
copalco committed Oct 30, 2023
1 parent 540505d commit eb6f42f
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 59 deletions.
157 changes: 128 additions & 29 deletions falcon/hooks.py
Expand Up @@ -29,18 +29,101 @@
import falcon as wsgi
from falcon import asgi

ResourceT = t.TypeVar('ResourceT', bound=type)

class ResourceClass(type, t.Generic[ResourceT]):
def __new__(
cls: t.Type[ResourceT],
name: str,
bases: t.Tuple[type],
attrs: t.Dict[str, t.Any],
) -> ResourceT:
if not any(
hasattr(cls, method)
for method in (
'on_head',
'on_get',
'on_post',
'on_put',
'on_delete',
'on_connect',
'on_options',
'on_trace',
'on_patch',
)
):
raise TypeError(
f"Can't instantiate class '{cls.__name__}', "
+ 'without at least one of the methods '
+ 'on_head',
'on_get',
'on_post',
'on_put',
'on_delete',
'on_connect',
'on_options',
'on_trace',
'on_patch',
)
return cls(name, bases, attrs)

class Resource(t.Protocol, metaclass=ResourceClass):
...

class SyncResponder(t.Protocol):
@staticmethod
def __call__(
self: SyncResponder, req: wsgi.Request, resp: wsgi.Response, **kwargs: t.Any
) -> None:
...

class AsyncResponder(t.Protocol):
@staticmethod
async def __call__(
self: AsyncResponder,
req: asgi.Request,
resp: asgi.Response,
**kwargs: t.Any,
) -> None:
...

Responder = t.Union[SyncResponder, AsyncResponder]
ResponderOrResource = t.Union[Resource, SyncResponder, AsyncResponder]

class SynchronousAction(t.Protocol):
@staticmethod
def __call__(
req: wsgi.Request,
resp: wsgi.Response,
resource: SyncResponder,
params: t.Mapping[str, str],
*args: t.Any,
**kwargs: t.Any,
) -> None:
...

class AsynchronousAction(t.Protocol):
@staticmethod
async def __call__(
req: asgi.Request,
resp: asgi.Response,
resource: AsyncResponder,
params: t.Mapping[str, str],
*args: t.Any,
**kwargs: t.Any,
) -> None:
...

Action = t.Union[SynchronousAction, AsynchronousAction]

_DECORABLE_METHOD_NAME = re.compile(
r'^on_({})(_\w+)?$'.format('|'.join(method.lower() for method in COMBINED_METHODS))
)

SynchronousResource = t.Callable[..., t.Any]
AsynchronousResource = t.Callable[..., t.Awaitable[t.Any]]
Resource = t.Union[SynchronousResource, AsynchronousResource]


def before(
action: Resource, *args: t.Any, is_async: bool = False, **kwargs: t.Any
) -> t.Callable[[Resource], Resource]:
action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any
) -> t.Callable[[ResponderOrResource], ResponderOrResource]:
"""Execute the given action function *before* the responder.
The `params` argument that is passed to the hook
Expand Down Expand Up @@ -90,17 +173,20 @@ def do_something(req, resp, resource, params):
*action*.
"""

def _before(responder_or_resource: Resource) -> Resource:
def _before(responder_or_resource: ResponderOrResource) -> ResponderOrResource:
if isinstance(responder_or_resource, type):
resource = responder_or_resource
resource = t.cast(Resource, responder_or_resource)

for responder_name, responder in getmembers(resource, callable):
if _DECORABLE_METHOD_NAME.match(responder_name):
# This pattern is necessary to capture the current value of
# responder in the do_before_all closure; otherwise, they
# will capture the same responder variable that is shared
# between iterations of the for loop, above.
def let(responder: Resource = responder) -> None:

def let(
responder: Responder = t.cast(Responder, responder)
) -> None:
do_before_all = _wrap_with_before(
responder, action, args, kwargs, is_async
)
Expand All @@ -112,7 +198,7 @@ def let(responder: Resource = responder) -> None:
return resource

else:
responder = responder_or_resource
responder = t.cast(SyncResponder, responder_or_resource)
do_before_one = _wrap_with_before(responder, action, args, kwargs, is_async)

return do_before_one
Expand All @@ -121,8 +207,8 @@ def let(responder: Resource = responder) -> None:


def after(
action: Resource, *args: t.Any, is_async: bool = False, **kwargs: t.Any
) -> t.Callable[[Resource], Resource]:
action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any
) -> t.Callable[[ResponderOrResource], ResponderOrResource]:
"""Execute the given action function *after* the responder.
Args:
Expand Down Expand Up @@ -155,14 +241,16 @@ def after(
*action*.
"""

def _after(responder_or_resource: Resource) -> Resource:
def _after(responder_or_resource: ResponderOrResource) -> ResponderOrResource:
if isinstance(responder_or_resource, type):
resource = responder_or_resource
resource = t.cast(Resource, responder_or_resource)

for responder_name, responder in getmembers(resource, callable):
if _DECORABLE_METHOD_NAME.match(responder_name):

def let(responder: Resource = responder) -> None:
def let(
responder: Responder = t.cast(Responder, responder)
) -> None:
do_after_all = _wrap_with_after(
responder, action, args, kwargs, is_async
)
Expand All @@ -174,7 +262,7 @@ def let(responder: Resource = responder) -> None:
return resource

else:
responder = responder_or_resource
responder = t.cast(Responder, responder_or_resource)
do_after_one = _wrap_with_after(responder, action, args, kwargs, is_async)

return do_after_one
Expand All @@ -188,12 +276,12 @@ def let(responder: Resource = responder) -> None:


def _wrap_with_after(
responder: Resource,
action: Resource,
responder: Responder,
action: Action,
action_args: t.Any,
action_kwargs: t.Any,
is_async: bool,
) -> Resource:
) -> Responder:
"""Execute the given action function after a responder method.
Args:
Expand All @@ -219,10 +307,12 @@ def _wrap_with_after(
async_action = _wrap_non_coroutine_unsafe(action)
else:
async_action = action
async_action = t.cast(AsynchronousAction, async_action)
async_responder = t.cast(AsyncResponder, responder)

@wraps(responder)
async def do_after(
self: Resource,
self: AsyncResponder,
req: asgi.Request,
resp: asgi.Response,
*args: t.Any,
Expand All @@ -231,15 +321,17 @@ async def do_after(
if args:
_merge_responder_args(args, kwargs, extra_argnames)

await responder(self, req, resp, **kwargs)
assert async_action
await async_responder(self, req, resp, **kwargs)
assert async_action, "Needed for type checking, it'll never be None"
await async_action(req, resp, self, *action_args, **action_kwargs)

else:
action = t.cast(SynchronousAction, action)
responder = t.cast(SyncResponder, responder)

@wraps(responder)
def do_after(
self: Resource,
self: SyncResponder,
req: wsgi.Request,
resp: wsgi.Response,
*args: t.Any,
Expand All @@ -249,18 +341,19 @@ def do_after(
_merge_responder_args(args, kwargs, extra_argnames)

responder(self, req, resp, **kwargs)
assert action, "Needed for type checking, it'll never be None"
action(req, resp, self, *action_args, **action_kwargs)

return do_after


def _wrap_with_before(
responder: Resource,
action: Resource,
responder: Responder,
action: Action,
action_args: t.Tuple[t.Any, ...],
action_kwargs: t.Dict[str, t.Any],
is_async: bool,
) -> t.Union[t.Callable[..., t.Awaitable[None]], t.Callable[..., None]]:
) -> Responder:
"""Execute the given action function before a responder method.
Args:
Expand All @@ -287,9 +380,12 @@ def _wrap_with_before(
else:
async_action = action

async_action = t.cast(AsynchronousAction, async_action)
responder = t.cast(AsyncResponder, responder)

@wraps(responder)
async def do_before(
self: Resource,
self: AsyncResponder,
req: asgi.Request,
resp: asgi.Response,
*args: t.Any,
Expand All @@ -298,15 +394,17 @@ async def do_before(
if args:
_merge_responder_args(args, kwargs, extra_argnames)

assert async_action
assert async_action, "Needed for type checking, it'll never be None"
await async_action(req, resp, self, kwargs, *action_args, **action_kwargs)
await responder(self, req, resp, **kwargs)

else:
action = t.cast(SynchronousAction, action)
responder = t.cast(SyncResponder, responder)

@wraps(responder)
def do_before(
self: Resource,
self: SyncResponder,
req: wsgi.Request,
resp: wsgi.Response,
*args: t.Any,
Expand All @@ -315,6 +413,7 @@ def do_before(
if args:
_merge_responder_args(args, kwargs, extra_argnames)

assert action, "Needed for type checking, it'll never be None"
action(req, resp, self, kwargs, *action_args, **action_kwargs)
responder(self, req, resp, **kwargs)

Expand Down

0 comments on commit eb6f42f

Please sign in to comment.