Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pytest_pycollect_makeitems #421

Merged
117 changes: 69 additions & 48 deletions pytest_asyncio/plugin.py
Expand Up @@ -24,6 +24,7 @@
)

import pytest
from pytest import Function, Session, Item

if sys.version_info >= (3, 8):
from typing import Literal
Expand Down Expand Up @@ -120,7 +121,7 @@ def fixture(
fixture_function: Optional[FixtureFunction] = None, **kwargs: Any
) -> Union[FixtureFunction, FixtureFunctionMarker]:
if fixture_function is not None:
_set_explicit_asyncio_mark(fixture_function)
_make_asyncio_fixture_function(fixture_function)
return pytest.fixture(fixture_function, **kwargs)

else:
Expand All @@ -132,12 +133,12 @@ def inner(fixture_function: FixtureFunction) -> FixtureFunction:
return inner


def _has_explicit_asyncio_mark(obj: Any) -> bool:
def _is_asyncio_fixture_function(obj: Any) -> bool:
obj = getattr(obj, "__func__", obj) # instance method maybe?
return getattr(obj, "_force_asyncio_fixture", False)


def _set_explicit_asyncio_mark(obj: Any) -> None:
def _make_asyncio_fixture_function(obj: Any) -> None:
if hasattr(obj, "__func__"):
# instance method, check the function object
obj = obj.__func__
Expand Down Expand Up @@ -177,41 +178,51 @@ def pytest_report_header(config: Config) -> List[str]:
return [f"asyncio: mode={mode}"]


def _preprocess_async_fixtures(config: Config, holder: Set[FixtureDef]) -> None:
def _preprocess_async_fixtures(
config: Config,
processed_fixturedefs: Set[FixtureDef],
) -> None:
asyncio_mode = _get_asyncio_mode(config)
fixturemanager = config.pluginmanager.get_plugin("funcmanage")
for fixtures in fixturemanager._arg2fixturedefs.values():
for fixturedef in fixtures:
if fixturedef in holder:
continue
func = fixturedef.func
if not _is_coroutine_or_asyncgen(func):
# Nothing to do with a regular fixture function
if fixturedef in processed_fixturedefs or not _is_coroutine_or_asyncgen(
func
):
continue
if not _is_asyncio_fixture_function(func) and asyncio_mode == Mode.STRICT:
# Ignore async fixtures without explicit asyncio mark in strict mode
# This applies to pytest_trio fixtures, for example
continue
if not _has_explicit_asyncio_mark(func):
if asyncio_mode == Mode.STRICT:
# Ignore async fixtures without explicit asyncio mark in strict mode
# This applies to pytest_trio fixtures, for example
continue
elif asyncio_mode == Mode.AUTO:
# Enforce asyncio mode if 'auto'
_set_explicit_asyncio_mark(func)
_make_asyncio_fixture_function(func)
_inject_fixture_argnames(fixturedef)
_synchronize_async_fixture(fixturedef)
assert _is_asyncio_fixture_function(fixturedef.func)
processed_fixturedefs.add(fixturedef)

to_add = []
for name in ("request", "event_loop"):
if name not in fixturedef.argnames:
to_add.append(name)

if to_add:
fixturedef.argnames += tuple(to_add)
def _inject_fixture_argnames(fixturedef: FixtureDef) -> None:
"""
Ensures that `request` and `event_loop` are arguments of the specified fixture.
"""
to_add = []
for name in ("request", "event_loop"):
if name not in fixturedef.argnames:
to_add.append(name)
if to_add:
fixturedef.argnames += tuple(to_add)

if inspect.isasyncgenfunction(func):
fixturedef.func = _wrap_asyncgen(func)
elif inspect.iscoroutinefunction(func):
fixturedef.func = _wrap_async(func)

assert _has_explicit_asyncio_mark(fixturedef.func)
holder.add(fixturedef)
def _synchronize_async_fixture(fixturedef: FixtureDef) -> None:
"""
Wraps the fixture function of an async fixture in a synchronous function.
"""
func = fixturedef.func
if inspect.isasyncgenfunction(func):
fixturedef.func = _wrap_asyncgen(func)
elif inspect.iscoroutinefunction(func):
fixturedef.func = _wrap_async(func)


def _add_kwargs(
Expand Down Expand Up @@ -281,36 +292,46 @@ async def setup() -> _R:

@pytest.mark.tryfirst
def pytest_pycollect_makeitem(
collector: Union[pytest.Module, pytest.Class], name: str, obj: object
collector: Union[pytest.Module, pytest.Class], name: str, obj: object
) -> Union[
None, pytest.Item, pytest.Collector, List[Union[pytest.Item, pytest.Collector]]
]:
"""A pytest hook to collect asyncio coroutines."""
if not collector.funcnamefilter(name):
return None
_preprocess_async_fixtures(collector.config, _HOLDER)
if isinstance(obj, staticmethod):
# staticmethods need to be unwrapped.
obj = obj.__func__
if (
_is_coroutine(obj)
or _is_hypothesis_test(obj)
and _hypothesis_test_wraps_coroutine(obj)
):
item = pytest.Function.from_parent(collector, name=name)
marker = item.get_closest_marker("asyncio")
if marker is not None:
return list(collector._genfunctions(name, obj))
else:
if _get_asyncio_mode(item.config) == Mode.AUTO:
# implicitly add asyncio marker if asyncio mode is on
ret = list(collector._genfunctions(name, obj))
for elem in ret:
elem.add_marker("asyncio")
return ret # type: ignore[return-value]
return None


def pytest_collection_modifyitems(
session: Session, config: Config, items: List[Item]
) -> None:
"""
Marks collected async test items as `asyncio` tests.

The mark is only applied in `AUTO` mode. It is applied to:

- coroutines
- staticmethods wrapping coroutines
- Hypothesis tests wrapping coroutines

"""
if _get_asyncio_mode(config) != Mode.AUTO:
return
function_items = (item for item in items if isinstance(item, Function))
for function_item in function_items:
function = function_item.obj
if isinstance(function, staticmethod):
# staticmethods need to be unwrapped.
function = function.__func__
if (
_is_coroutine(function)
or _is_hypothesis_test(function)
and _hypothesis_test_wraps_coroutine(function)
):
function_item.add_marker("asyncio")


def _hypothesis_test_wraps_coroutine(function: Any) -> bool:
return _is_coroutine(function.hypothesis.inner_test)

Expand Down