Skip to content

Commit

Permalink
Add async_stub method (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
PerchunPak committed Jun 24, 2022
1 parent 6030ef2 commit 7cc4cec
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 0 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -2,6 +2,12 @@ Releases
========


3.8.0 (unreleased)
------------------
* Add ``MockerFixture.async_mock`` method. Thanks `@PerchunPak`_ for the PR.

.. _@PerchunPak: https://github.com/PerchunPak

3.7.0 (2022-01-28)
------------------

Expand Down
4 changes: 4 additions & 0 deletions docs/usage.rst
Expand Up @@ -110,3 +110,7 @@ It may receive an optional name that is shown in its ``repr``, useful for debugg
foo(stub)
stub.assert_called_once_with('foo', 'bar')
.. seealso::

``async_stub`` method, which actually the same as ``stub`` but makes async stub.
19 changes: 19 additions & 0 deletions src/pytest_mock/plugin.py
Expand Up @@ -27,6 +27,12 @@

_T = TypeVar("_T")

if sys.version_info[:2] > (3, 7):
AsyncMockType = unittest.mock.AsyncMock
else:
import mock
AsyncMockType = mock.AsyncMock


class PytestMockWarning(UserWarning):
"""Base class for all warnings emitted by pytest-mock."""
Expand Down Expand Up @@ -159,6 +165,19 @@ def stub(self, name: Optional[str] = None) -> unittest.mock.MagicMock:
self.mock_module.MagicMock(spec=lambda *args, **kwargs: None, name=name),
)

def async_stub(self, name: Optional[str] = None) -> AsyncMockType:
"""
Create a async stub method. It accepts any arguments. Ideal to register to
callbacks in tests.
:param name: the constructed stub's name as used in repr
:return: Stub object.
"""
return cast(
unittest.mock.AsyncMock,
self.mock_module.AsyncMock(spec=lambda *args, **kwargs: None, name=name),
)

class _Patcher:
"""
Object to provide the same interface as mock.patch, mock.patch.object,
Expand Down
7 changes: 7 additions & 0 deletions tests/test_pytest_mock.py
Expand Up @@ -26,6 +26,9 @@
# Python 3.8 changed the output formatting (bpo-35500), which has been ported to mock 3.0
NEW_FORMATTING = sys.version_info >= (3, 8)

if sys.version_info[:2] >= (3, 8):
from unittest.mock import AsyncMock


@pytest.fixture
def needs_assert_rewrite(pytestconfig):
Expand Down Expand Up @@ -232,6 +235,10 @@ def test_failure_message_with_no_name(self, mocker: MagicMock) -> None:
@pytest.mark.parametrize("name", (None, "", "f", "The Castle of aaarrrrggh"))
def test_failure_message_with_name(self, mocker: MagicMock, name: str) -> None:
self.__test_failure_message(mocker, name=name)

@pytest.mark.skipif(sys.version_info[:2] < (3, 8), reason="This Python version doesn't have `AsyncMock`.")
def test_async_stub_type(self, mocker: MockerFixture) -> None:
assert isinstance(mocker.async_stub(), AsyncMock)


def test_instance_method_spy(mocker: MockerFixture) -> None:
Expand Down

0 comments on commit 7cc4cec

Please sign in to comment.