Skip to content

Commit

Permalink
Mark pool-wrapped connection coroutine methods as coroutines
Browse files Browse the repository at this point in the history
Use `markcoroutinefunction` (available in Python 3.12+) to make
`inspect.iscoroutinefunction()` return the correct answer for wrapped
connection methods.

Fixes: #1133
  • Loading branch information
elprans committed Mar 15, 2024
1 parent 1aab209 commit a2055ae
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
7 changes: 7 additions & 0 deletions asyncpg/compat.py
Expand Up @@ -52,6 +52,13 @@ async def wait_closed(stream: asyncio.StreamWriter) -> None:
pass


if sys.version_info < (3, 12):
def markcoroutinefunction(c):
pass
else:
from inspect import markcoroutinefunction # noqa: F401


if sys.version_info < (3, 12):
from ._asyncio_compat import wait_for as wait_for # noqa: F401
else:
Expand Down
8 changes: 6 additions & 2 deletions asyncpg/pool.py
Expand Up @@ -33,7 +33,8 @@ def __new__(mcls, name, bases, dct, *, wrap=False):
if not inspect.isfunction(meth):
continue

wrapper = mcls._wrap_connection_method(attrname)
iscoroutine = inspect.iscoroutinefunction(meth)
wrapper = mcls._wrap_connection_method(attrname, iscoroutine)
wrapper = functools.update_wrapper(wrapper, meth)
dct[attrname] = wrapper

Expand All @@ -43,7 +44,7 @@ def __new__(mcls, name, bases, dct, *, wrap=False):
return super().__new__(mcls, name, bases, dct)

@staticmethod
def _wrap_connection_method(meth_name):
def _wrap_connection_method(meth_name, iscoroutine):
def call_con_method(self, *args, **kwargs):
# This method will be owned by PoolConnectionProxy class.
if self._con is None:
Expand All @@ -55,6 +56,9 @@ def call_con_method(self, *args, **kwargs):
meth = getattr(self._con.__class__, meth_name)
return meth(self._con, *args, **kwargs)

if iscoroutine:
compat.markcoroutinefunction(call_con_method)

return call_con_method


Expand Down

0 comments on commit a2055ae

Please sign in to comment.