Skip to content

Commit

Permalink
Provide typing info (#260)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed Jan 15, 2022
1 parent 0a3328f commit 308f308
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 42 deletions.
5 changes: 3 additions & 2 deletions Makefile
Expand Up @@ -23,10 +23,11 @@ clean-test: ## remove test and coverage artifacts
lint:
# CI env-var is set by GitHub actions
ifdef CI
pre-commit run --all-files --show-diff-on-failure
python -m pre_commit run --all-files --show-diff-on-failure
else
pre-commit run --all-files
python -m pre_commit run --all-files
endif
python -m mypy pytest_asyncio --show-error-codes

test:
coverage run -m pytest tests
Expand Down
1 change: 1 addition & 0 deletions README.rst
Expand Up @@ -260,6 +260,7 @@ Changelog
~~~~~~~~~~~~~~~~~~~
- Fixes a bug that prevents async Hypothesis tests from working without explicit ``asyncio`` marker when ``--asyncio-mode=auto`` is set. `#258 <https://github.com/pytest-dev/pytest-asyncio/issues/258>`_
- Fixed a bug that closes the default event loop if the loop doesn't exist `#257 <https://github.com/pytest-dev/pytest-asyncio/issues/257>`_
- Added type annotations. `#198 <https://github.com/pytest-dev/pytest-asyncio/issues/198>`_

0.17.0 (22-01-13)
~~~~~~~~~~~~~~~~~~~
Expand Down
166 changes: 127 additions & 39 deletions pytest_asyncio/plugin.py
Expand Up @@ -6,8 +6,45 @@
import inspect
import socket
import warnings
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
TypeVar,
Union,
cast,
overload,
)

import pytest
from typing_extensions import Literal

_R = TypeVar("_R")

_ScopeName = Literal["session", "package", "module", "class", "function"]
_T = TypeVar("_T")

SimpleFixtureFunction = TypeVar(
"SimpleFixtureFunction", bound=Callable[..., Awaitable[_R]]
)
FactoryFixtureFunction = TypeVar(
"FactoryFixtureFunction", bound=Callable[..., AsyncIterator[_R]]
)
FixtureFunction = Union[SimpleFixtureFunction, FactoryFixtureFunction]
FixtureFunctionMarker = Callable[[FixtureFunction], FixtureFunction]

Config = Any # pytest < 7.0
PytestPluginManager = Any # pytest < 7.0
FixtureDef = Any # pytest < 7.0
Parser = Any # pytest < 7.0
SubRequest = Any # pytest < 7.0


class Mode(str, enum.Enum):
Expand Down Expand Up @@ -41,7 +78,7 @@ class Mode(str, enum.Enum):
"""


def pytest_addoption(parser, pluginmanager):
def pytest_addoption(parser: Parser, pluginmanager: PytestPluginManager) -> None:
group = parser.getgroup("asyncio")
group.addoption(
"--asyncio-mode",
Expand All @@ -58,49 +95,87 @@ def pytest_addoption(parser, pluginmanager):
)


def fixture(fixture_function=None, **kwargs):
@overload
def fixture(
fixture_function: FixtureFunction,
*,
scope: "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
params: Optional[Iterable[object]] = ...,
autouse: bool = ...,
ids: Optional[
Union[
Iterable[Union[None, str, float, int, bool]],
Callable[[Any], Optional[object]],
]
] = ...,
name: Optional[str] = ...,
) -> FixtureFunction:
...


@overload
def fixture(
fixture_function: None = ...,
*,
scope: "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
params: Optional[Iterable[object]] = ...,
autouse: bool = ...,
ids: Optional[
Union[
Iterable[Union[None, str, float, int, bool]],
Callable[[Any], Optional[object]],
]
] = ...,
name: Optional[str] = None,
) -> FixtureFunctionMarker:
...


def fixture(
fixture_function: Optional[FixtureFunction] = None, **kwargs: Any
) -> Union[FixtureFunction, FixtureFunctionMarker]:
if fixture_function is not None:
_set_explicit_asyncio_mark(fixture_function)
return pytest.fixture(fixture_function, **kwargs)

else:

@functools.wraps(fixture)
def inner(fixture_function):
def inner(fixture_function: FixtureFunction) -> FixtureFunction:
return fixture(fixture_function, **kwargs)

return inner


def _has_explicit_asyncio_mark(obj):
def _has_explicit_asyncio_mark(obj: Any) -> bool:
obj = getattr(obj, "__func__", obj) # instance method maybe?
return getattr(obj, "_force_asyncio_fixture", False)


def _set_explicit_asyncio_mark(obj):
def _set_explicit_asyncio_mark(obj: Any) -> None:
if hasattr(obj, "__func__"):
# instance method, check the function object
obj = obj.__func__
obj._force_asyncio_fixture = True


def _is_coroutine(obj):
def _is_coroutine(obj: Any) -> bool:
"""Check to see if an object is really an asyncio coroutine."""
return asyncio.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj)


def _is_coroutine_or_asyncgen(obj):
def _is_coroutine_or_asyncgen(obj: Any) -> bool:
return _is_coroutine(obj) or inspect.isasyncgenfunction(obj)


def _get_asyncio_mode(config):
def _get_asyncio_mode(config: Config) -> Mode:
val = config.getoption("asyncio_mode")
if val is None:
val = config.getini("asyncio_mode")
return Mode(val)


def pytest_configure(config):
def pytest_configure(config: Config) -> None:
"""Inject documentation."""
config.addinivalue_line(
"markers",
Expand All @@ -113,10 +188,14 @@ def pytest_configure(config):


@pytest.mark.tryfirst
def pytest_pycollect_makeitem(collector, name, obj):
def pytest_pycollect_makeitem(
collector: Union[pytest.Module, pytest.Class], name: str, obj: object
) -> Union[
None, pytest.Item, pytest.Collector, List[Union[pytest.Item, pytest.Collector]]
]:
"""A pytest hook to collect asyncio coroutines."""
if not collector.funcnamefilter(name):
return
return None
if (
_is_coroutine(obj)
or _is_hypothesis_test(obj)
Expand All @@ -131,10 +210,11 @@ def pytest_pycollect_makeitem(collector, name, obj):
ret = list(collector._genfunctions(name, obj))
for elem in ret:
elem.add_marker("asyncio")
return ret
return ret # type: ignore[return-value]
return None


def _hypothesis_test_wraps_coroutine(function):
def _hypothesis_test_wraps_coroutine(function: Any) -> bool:
return _is_coroutine(function.hypothesis.inner_test)


Expand All @@ -144,19 +224,19 @@ class FixtureStripper:
REQUEST = "request"
EVENT_LOOP = "event_loop"

def __init__(self, fixturedef):
def __init__(self, fixturedef: FixtureDef) -> None:
self.fixturedef = fixturedef
self.to_strip = set()
self.to_strip: Set[str] = set()

def add(self, name):
def add(self, name: str) -> None:
"""Add fixture name to fixturedef
and record in to_strip list (If not previously included)"""
if name in self.fixturedef.argnames:
return
self.fixturedef.argnames += (name,)
self.to_strip.add(name)

def get_and_strip_from(self, name, data_dict):
def get_and_strip_from(self, name: str, data_dict: Dict[str, _T]) -> _T:
"""Strip name from data, and return value"""
result = data_dict[name]
if name in self.to_strip:
Expand All @@ -165,7 +245,7 @@ def get_and_strip_from(self, name, data_dict):


@pytest.hookimpl(trylast=True)
def pytest_fixture_post_finalizer(fixturedef, request):
def pytest_fixture_post_finalizer(fixturedef: FixtureDef, request: SubRequest) -> None:
"""Called after fixture teardown"""
if fixturedef.argname == "event_loop":
policy = asyncio.get_event_loop_policy()
Expand All @@ -182,7 +262,9 @@ def pytest_fixture_post_finalizer(fixturedef, request):


@pytest.hookimpl(hookwrapper=True)
def pytest_fixture_setup(fixturedef, request):
def pytest_fixture_setup(
fixturedef: FixtureDef, request: SubRequest
) -> Optional[object]:
"""Adjust the event loop policy when an event loop is produced."""
if fixturedef.argname == "event_loop":
outcome = yield
Expand Down Expand Up @@ -295,39 +377,43 @@ async def setup():


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_pyfunc_call(pyfuncitem):
def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> Optional[object]:
"""
Pytest hook called before a test case is run.
Wraps marked tests in a synchronous function
where the wrapped test coroutine is executed in an event loop.
"""
if "asyncio" in pyfuncitem.keywords:
funcargs: Dict[str, object] = pyfuncitem.funcargs # type: ignore[name-defined]
loop = cast(asyncio.AbstractEventLoop, funcargs["event_loop"])
if _is_hypothesis_test(pyfuncitem.obj):
pyfuncitem.obj.hypothesis.inner_test = wrap_in_sync(
pyfuncitem.obj.hypothesis.inner_test,
_loop=pyfuncitem.funcargs["event_loop"],
_loop=loop,
)
else:
pyfuncitem.obj = wrap_in_sync(
pyfuncitem.obj, _loop=pyfuncitem.funcargs["event_loop"]
pyfuncitem.obj,
_loop=loop,
)
yield


def _is_hypothesis_test(function) -> bool:
def _is_hypothesis_test(function: Any) -> bool:
return getattr(function, "is_hypothesis_test", False)


def wrap_in_sync(func, _loop):
def wrap_in_sync(func: Callable[..., Awaitable[Any]], _loop: asyncio.AbstractEventLoop):
"""Return a sync wrapper around an async function executing it in the
current event loop."""

# if the function is already wrapped, we rewrap using the original one
# not using __wrapped__ because the original function may already be
# a wrapped one
if hasattr(func, "_raw_test_func"):
func = func._raw_test_func
raw_func = getattr(func, "_raw_test_func", None)
if raw_func is not None:
func = raw_func

@functools.wraps(func)
def inner(**kwargs):
Expand All @@ -344,20 +430,22 @@ def inner(**kwargs):
task.exception()
raise

inner._raw_test_func = func
inner._raw_test_func = func # type: ignore[attr-defined]
return inner


def pytest_runtest_setup(item):
def pytest_runtest_setup(item: pytest.Item) -> None:
if "asyncio" in item.keywords:
fixturenames = item.fixturenames # type: ignore[attr-defined]
# inject an event loop fixture for all async tests
if "event_loop" in item.fixturenames:
item.fixturenames.remove("event_loop")
item.fixturenames.insert(0, "event_loop")
if "event_loop" in fixturenames:
fixturenames.remove("event_loop")
fixturenames.insert(0, "event_loop")
obj = item.obj # type: ignore[attr-defined]
if (
item.get_closest_marker("asyncio") is not None
and not getattr(item.obj, "hypothesis", False)
and getattr(item.obj, "is_hypothesis_test", False)
and not getattr(obj, "hypothesis", False)
and getattr(obj, "is_hypothesis_test", False)
):
pytest.fail(
"test function `%r` is using Hypothesis, but pytest-asyncio "
Expand All @@ -366,32 +454,32 @@ def pytest_runtest_setup(item):


@pytest.fixture
def event_loop(request):
def event_loop(request: pytest.FixtureRequest) -> Iterator[asyncio.AbstractEventLoop]:
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()


def _unused_port(socket_type):
def _unused_port(socket_type: int) -> int:
"""Find an unused localhost port from 1024-65535 and return it."""
with contextlib.closing(socket.socket(type=socket_type)) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]


@pytest.fixture
def unused_tcp_port():
def unused_tcp_port() -> int:
return _unused_port(socket.SOCK_STREAM)


@pytest.fixture
def unused_udp_port():
def unused_udp_port() -> int:
return _unused_port(socket.SOCK_DGRAM)


@pytest.fixture(scope="session")
def unused_tcp_port_factory():
def unused_tcp_port_factory() -> Callable[[], int]:
"""A factory function, producing different unused TCP ports."""
produced = set()

Expand All @@ -410,7 +498,7 @@ def factory():


@pytest.fixture(scope="session")
def unused_udp_port_factory():
def unused_udp_port_factory() -> Callable[[], int]:
"""A factory function, producing different unused UDP ports."""
produced = set()

Expand Down
Empty file added pytest_asyncio/py.typed
Empty file.
3 changes: 3 additions & 0 deletions setup.cfg
Expand Up @@ -27,6 +27,7 @@ classifiers =

Framework :: AsyncIO
Framework :: Pytest
Typing :: Typed

[options]
python_requires = >=3.7
Expand All @@ -38,12 +39,14 @@ setup_requires =

install_requires =
pytest >= 5.4.0
typing-extensions >= 4.0

[options.extras_require]
testing =
coverage==6.2
hypothesis >= 5.7.1
flaky >= 3.5.0
mypy == 0.931

[options.entry_points]
pytest11 =
Expand Down

0 comments on commit 308f308

Please sign in to comment.