diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index 6fd300d5..b3dc9d67 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -24,6 +24,7 @@ ) import pytest +from pytest import Function, Session, Item if sys.version_info >= (3, 8): from typing import Literal @@ -120,7 +121,7 @@ def fixture( fixture_function: Optional[FixtureFunction] = None, **kwargs: Any ) -> Union[FixtureFunction, FixtureFunctionMarker]: if fixture_function is not None: - _set_explicit_asyncio_mark(fixture_function) + _make_asyncio_fixture_function(fixture_function) return pytest.fixture(fixture_function, **kwargs) else: @@ -132,12 +133,12 @@ def inner(fixture_function: FixtureFunction) -> FixtureFunction: return inner -def _has_explicit_asyncio_mark(obj: Any) -> bool: +def _is_asyncio_fixture_function(obj: Any) -> bool: obj = getattr(obj, "__func__", obj) # instance method maybe? return getattr(obj, "_force_asyncio_fixture", False) -def _set_explicit_asyncio_mark(obj: Any) -> None: +def _make_asyncio_fixture_function(obj: Any) -> None: if hasattr(obj, "__func__"): # instance method, check the function object obj = obj.__func__ @@ -177,41 +178,51 @@ def pytest_report_header(config: Config) -> List[str]: return [f"asyncio: mode={mode}"] -def _preprocess_async_fixtures(config: Config, holder: Set[FixtureDef]) -> None: +def _preprocess_async_fixtures( + config: Config, + processed_fixturedefs: Set[FixtureDef], +) -> None: asyncio_mode = _get_asyncio_mode(config) fixturemanager = config.pluginmanager.get_plugin("funcmanage") for fixtures in fixturemanager._arg2fixturedefs.values(): for fixturedef in fixtures: - if fixturedef in holder: - continue func = fixturedef.func - if not _is_coroutine_or_asyncgen(func): - # Nothing to do with a regular fixture function + if fixturedef in processed_fixturedefs or not _is_coroutine_or_asyncgen( + func + ): + continue + if not _is_asyncio_fixture_function(func) and asyncio_mode == Mode.STRICT: + # Ignore async fixtures without explicit asyncio mark in strict mode + # This applies to pytest_trio fixtures, for example continue - if not _has_explicit_asyncio_mark(func): - if asyncio_mode == Mode.STRICT: - # Ignore async fixtures without explicit asyncio mark in strict mode - # This applies to pytest_trio fixtures, for example - continue - elif asyncio_mode == Mode.AUTO: - # Enforce asyncio mode if 'auto' - _set_explicit_asyncio_mark(func) + _make_asyncio_fixture_function(func) + _inject_fixture_argnames(fixturedef) + _synchronize_async_fixture(fixturedef) + assert _is_asyncio_fixture_function(fixturedef.func) + processed_fixturedefs.add(fixturedef) - to_add = [] - for name in ("request", "event_loop"): - if name not in fixturedef.argnames: - to_add.append(name) - if to_add: - fixturedef.argnames += tuple(to_add) +def _inject_fixture_argnames(fixturedef: FixtureDef) -> None: + """ + Ensures that `request` and `event_loop` are arguments of the specified fixture. + """ + to_add = [] + for name in ("request", "event_loop"): + if name not in fixturedef.argnames: + to_add.append(name) + if to_add: + fixturedef.argnames += tuple(to_add) - if inspect.isasyncgenfunction(func): - fixturedef.func = _wrap_asyncgen(func) - elif inspect.iscoroutinefunction(func): - fixturedef.func = _wrap_async(func) - assert _has_explicit_asyncio_mark(fixturedef.func) - holder.add(fixturedef) +def _synchronize_async_fixture(fixturedef: FixtureDef) -> None: + """ + Wraps the fixture function of an async fixture in a synchronous function. + """ + func = fixturedef.func + if inspect.isasyncgenfunction(func): + fixturedef.func = _wrap_asyncgen(func) + elif inspect.iscoroutinefunction(func): + fixturedef.func = _wrap_async(func) def _add_kwargs( @@ -281,7 +292,7 @@ async def setup() -> _R: @pytest.mark.tryfirst def pytest_pycollect_makeitem( - collector: Union[pytest.Module, pytest.Class], name: str, obj: object + collector: Union[pytest.Module, pytest.Class], name: str, obj: object ) -> Union[ None, pytest.Item, pytest.Collector, List[Union[pytest.Item, pytest.Collector]] ]: @@ -289,28 +300,38 @@ def pytest_pycollect_makeitem( if not collector.funcnamefilter(name): return None _preprocess_async_fixtures(collector.config, _HOLDER) - if isinstance(obj, staticmethod): - # staticmethods need to be unwrapped. - obj = obj.__func__ - if ( - _is_coroutine(obj) - or _is_hypothesis_test(obj) - and _hypothesis_test_wraps_coroutine(obj) - ): - item = pytest.Function.from_parent(collector, name=name) - marker = item.get_closest_marker("asyncio") - if marker is not None: - return list(collector._genfunctions(name, obj)) - else: - if _get_asyncio_mode(item.config) == Mode.AUTO: - # implicitly add asyncio marker if asyncio mode is on - ret = list(collector._genfunctions(name, obj)) - for elem in ret: - elem.add_marker("asyncio") - return ret # type: ignore[return-value] return None +def pytest_collection_modifyitems( + session: Session, config: Config, items: List[Item] +) -> None: + """ + Marks collected async test items as `asyncio` tests. + + The mark is only applied in `AUTO` mode. It is applied to: + + - coroutines + - staticmethods wrapping coroutines + - Hypothesis tests wrapping coroutines + + """ + if _get_asyncio_mode(config) != Mode.AUTO: + return + function_items = (item for item in items if isinstance(item, Function)) + for function_item in function_items: + function = function_item.obj + if isinstance(function, staticmethod): + # staticmethods need to be unwrapped. + function = function.__func__ + if ( + _is_coroutine(function) + or _is_hypothesis_test(function) + and _hypothesis_test_wraps_coroutine(function) + ): + function_item.add_marker("asyncio") + + def _hypothesis_test_wraps_coroutine(function: Any) -> bool: return _is_coroutine(function.hypothesis.inner_test)