Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support parametrized event_loop fixture #278

Merged
merged 4 commits into from Jan 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions README.rst
Expand Up @@ -261,6 +261,7 @@ Changelog
~~~~~~~~~~~~~~~~~~~

- Raise a warning if @pytest.mark.asyncio is applied to non-async function. `#275 <https://github.com/pytest-dev/pytest-asyncio/issues/275>`_
- Support parametrized ``event_loop`` fixture. `#278 <https://github.com/pytest-dev/pytest-asyncio/issues/278>`_

0.17.2 (22-01-17)
~~~~~~~~~~~~~~~~~~~
Expand Down
229 changes: 114 additions & 115 deletions pytest_asyncio/plugin.py
Expand Up @@ -165,7 +165,7 @@ def _set_explicit_asyncio_mark(obj: Any) -> None:

def _is_coroutine(obj: Any) -> bool:
"""Check to see if an object is really an asyncio coroutine."""
return asyncio.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj)
return asyncio.iscoroutinefunction(obj)


def _is_coroutine_or_asyncgen(obj: Any) -> bool:
Expand Down Expand Up @@ -198,6 +198,118 @@ def pytest_report_header(config: Config) -> List[str]:
return [f"asyncio: mode={mode}"]


def _preprocess_async_fixtures(config: Config, holder: Set[FixtureDef]) -> None:
asyncio_mode = _get_asyncio_mode(config)
fixturemanager = config.pluginmanager.get_plugin("funcmanage")
for fixtures in fixturemanager._arg2fixturedefs.values():
for fixturedef in fixtures:
if fixturedef is holder:
continue
func = fixturedef.func
if not _is_coroutine_or_asyncgen(func):
# Nothing to do with a regular fixture function
continue
if not _has_explicit_asyncio_mark(func):
if asyncio_mode == Mode.AUTO:
# Enforce asyncio mode if 'auto'
_set_explicit_asyncio_mark(func)
elif asyncio_mode == Mode.LEGACY:
_set_explicit_asyncio_mark(func)
try:
code = func.__code__
except AttributeError:
code = func.__func__.__code__
name = (
f"<fixture {func.__qualname__}, file={code.co_filename}, "
f"line={code.co_firstlineno}>"
)
warnings.warn(
LEGACY_ASYNCIO_FIXTURE.format(name=name),
DeprecationWarning,
)

to_add = []
for name in ("request", "event_loop"):
if name not in fixturedef.argnames:
to_add.append(name)

if to_add:
fixturedef.argnames += tuple(to_add)

if inspect.isasyncgenfunction(func):
fixturedef.func = _wrap_asyncgen(func)
elif inspect.iscoroutinefunction(func):
fixturedef.func = _wrap_async(func)

assert _has_explicit_asyncio_mark(fixturedef.func)
holder.add(fixturedef)


def _add_kwargs(
func: Callable[..., Any],
kwargs: Dict[str, Any],
event_loop: asyncio.AbstractEventLoop,
request: SubRequest,
) -> Dict[str, Any]:
sig = inspect.signature(func)
ret = kwargs.copy()
if "request" in sig.parameters:
ret["request"] = request
if "event_loop" in sig.parameters:
ret["event_loop"] = event_loop
return ret


def _wrap_asyncgen(func: Callable[..., AsyncIterator[_R]]) -> Callable[..., _R]:
@functools.wraps(func)
def _asyncgen_fixture_wrapper(
event_loop: asyncio.AbstractEventLoop, request: SubRequest, **kwargs: Any
) -> _R:
gen_obj = func(**_add_kwargs(func, kwargs, event_loop, request))

async def setup() -> _R:
res = await gen_obj.__anext__()
return res

def finalizer() -> None:
"""Yield again, to finalize."""

async def async_finalizer() -> None:
try:
await gen_obj.__anext__()
except StopAsyncIteration:
pass
else:
msg = "Async generator fixture didn't stop."
msg += "Yield only once."
raise ValueError(msg)

event_loop.run_until_complete(async_finalizer())

result = event_loop.run_until_complete(setup())
request.addfinalizer(finalizer)
return result

return _asyncgen_fixture_wrapper


def _wrap_async(func: Callable[..., Awaitable[_R]]) -> Callable[..., _R]:
@functools.wraps(func)
def _async_fixture_wrapper(
event_loop: asyncio.AbstractEventLoop, request: SubRequest, **kwargs: Any
) -> _R:
async def setup() -> _R:
res = await func(**_add_kwargs(func, kwargs, event_loop, request))
return res

return event_loop.run_until_complete(setup())

return _async_fixture_wrapper


_HOLDER: Set[FixtureDef] = set()


@pytest.mark.tryfirst
def pytest_pycollect_makeitem(
collector: Union[pytest.Module, pytest.Class], name: str, obj: object
Expand All @@ -212,6 +324,7 @@ def pytest_pycollect_makeitem(
or _is_hypothesis_test(obj)
and _hypothesis_test_wraps_coroutine(obj)
):
_preprocess_async_fixtures(collector.config, _HOLDER)
item = pytest.Function.from_parent(collector, name=name)
marker = item.get_closest_marker("asyncio")
if marker is not None:
Expand All @@ -230,31 +343,6 @@ def _hypothesis_test_wraps_coroutine(function: Any) -> bool:
return _is_coroutine(function.hypothesis.inner_test)


class FixtureStripper:
"""Include additional Fixture, and then strip them"""

EVENT_LOOP = "event_loop"

def __init__(self, fixturedef: FixtureDef) -> None:
self.fixturedef = fixturedef
self.to_strip: Set[str] = set()

def add(self, name: str) -> None:
"""Add fixture name to fixturedef
and record in to_strip list (If not previously included)"""
if name in self.fixturedef.argnames:
return
self.fixturedef.argnames += (name,)
self.to_strip.add(name)

def get_and_strip_from(self, name: str, data_dict: Dict[str, _T]) -> _T:
"""Strip name from data, and return value"""
result = data_dict[name]
if name in self.to_strip:
del data_dict[name]
return result


@pytest.hookimpl(trylast=True)
def pytest_fixture_post_finalizer(fixturedef: FixtureDef, request: SubRequest) -> None:
"""Called after fixture teardown"""
Expand Down Expand Up @@ -291,95 +379,6 @@ def pytest_fixture_setup(
policy.set_event_loop(loop)
return

func = fixturedef.func
if not _is_coroutine_or_asyncgen(func):
# Nothing to do with a regular fixture function
yield
return

config = request.node.config
asyncio_mode = _get_asyncio_mode(config)

if not _has_explicit_asyncio_mark(func):
if asyncio_mode == Mode.AUTO:
# Enforce asyncio mode if 'auto'
_set_explicit_asyncio_mark(func)
elif asyncio_mode == Mode.LEGACY:
_set_explicit_asyncio_mark(func)
try:
code = func.__code__
except AttributeError:
code = func.__func__.__code__
name = (
f"<fixture {func.__qualname__}, file={code.co_filename}, "
f"line={code.co_firstlineno}>"
)
warnings.warn(
LEGACY_ASYNCIO_FIXTURE.format(name=name),
DeprecationWarning,
)
else:
# asyncio_mode is STRICT,
# don't handle fixtures that are not explicitly marked
yield
return

if inspect.isasyncgenfunction(func):
# This is an async generator function. Wrap it accordingly.
generator = func

fixture_stripper = FixtureStripper(fixturedef)
fixture_stripper.add(FixtureStripper.EVENT_LOOP)

def wrapper(*args, **kwargs):
loop = fixture_stripper.get_and_strip_from(
FixtureStripper.EVENT_LOOP, kwargs
)

gen_obj = generator(*args, **kwargs)

async def setup():
res = await gen_obj.__anext__()
return res

def finalizer():
"""Yield again, to finalize."""

async def async_finalizer():
try:
await gen_obj.__anext__()
except StopAsyncIteration:
pass
else:
msg = "Async generator fixture didn't stop."
msg += "Yield only once."
raise ValueError(msg)

loop.run_until_complete(async_finalizer())

result = loop.run_until_complete(setup())
request.addfinalizer(finalizer)
return result

fixturedef.func = wrapper
elif inspect.iscoroutinefunction(func):
coro = func

fixture_stripper = FixtureStripper(fixturedef)
fixture_stripper.add(FixtureStripper.EVENT_LOOP)

def wrapper(*args, **kwargs):
loop = fixture_stripper.get_and_strip_from(
FixtureStripper.EVENT_LOOP, kwargs
)

async def setup():
res = await coro(*args, **kwargs)
return res

return loop.run_until_complete(setup())

fixturedef.func = wrapper
yield


Expand Down
31 changes: 31 additions & 0 deletions tests/async_fixtures/test_parametrized_loop.py
@@ -0,0 +1,31 @@
import asyncio

import pytest

TESTS_COUNT = 0


def teardown_module():
# parametrized 2 * 2 times: 2 for 'event_loop' and 2 for 'fix'
assert TESTS_COUNT == 4


@pytest.fixture(scope="module", params=[1, 2])
def event_loop(request):
request.param
loop = asyncio.new_event_loop()
yield loop
loop.close()


@pytest.fixture(params=["a", "b"])
async def fix(request):
await asyncio.sleep(0)
return request.param


@pytest.mark.asyncio
async def test_parametrized_loop(fix):
await asyncio.sleep(0)
global TESTS_COUNT
TESTS_COUNT += 1