Skip to content

Commit

Permalink
bpo-38108: Makes mock objects inherit from Base (GH-16060)
Browse files Browse the repository at this point in the history
Backports: 9a7d9519506ae807ca48ff02e2ea117ebac3450e
Signed-off-by: Chris Withers <chris@withers.org>
  • Loading branch information
lisroach authored and cjw296 committed Jan 22, 2020
1 parent 1c9c843 commit 7b643e5
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 57 deletions.
2 changes: 2 additions & 0 deletions NEWS.d/2019-09-25-21-37-02.bpo-38108.Jr9HU6.rst
@@ -0,0 +1,2 @@
Any synchronous magic methods on an AsyncMock now return a MagicMock. Any
asynchronous magic methods on a MagicMock now return an AsyncMock.
54 changes: 19 additions & 35 deletions mock/mock.py
Expand Up @@ -411,7 +411,7 @@ def __new__(cls, *args, **kw):
if spec_arg and _is_async_obj(spec_arg):
bases = (AsyncMockMixin, cls)
new = type(cls.__name__, bases, {'__doc__': cls.__doc__})
instance = object.__new__(new)
instance = _safe_super(NonCallableMock, cls).__new__(new)
return instance


Expand Down Expand Up @@ -997,17 +997,18 @@ def _get_child_mock(self, **kw):

_type = type(self)
if issubclass(_type, MagicMock) and _new_name in _async_method_magics:
# Any asynchronous magic becomes an AsyncMock
klass = AsyncMock
elif _new_name in _sync_async_magics:
# Special case these ones b/c users will assume they are async,
# but they are actually sync (ie. __aiter__)
klass = MagicMock
elif issubclass(_type, AsyncMockMixin):
klass = AsyncMock
if _new_name in _all_sync_magics:
# Any synchronous magic becomes a MagicMock
klass = MagicMock
else:
klass = AsyncMock
elif not issubclass(_type, CallableMixin):
if issubclass(_type, NonCallableMagicMock):
klass = MagicMock
elif issubclass(_type, NonCallableMock) :
elif issubclass(_type, NonCallableMock):
klass = Mock
else:
klass = _type.__mro__[1]
Expand Down Expand Up @@ -1895,6 +1896,7 @@ def _patch_stopall():
"round trunc floor ceil "
"bool next "
"fspath "
"aiter "
)

if IS_PYPY:
Expand Down Expand Up @@ -2037,21 +2039,22 @@ def _set_return_value(mock, method, name):



class MagicMixin(object):
class MagicMixin(Base):
def __init__(self, *args, **kw):
self._mock_set_magics() # make magic work for kwargs in init
_safe_super(MagicMixin, self).__init__(*args, **kw)
self._mock_set_magics() # fix magic broken by upper level init


def _mock_set_magics(self):
these_magics = _magics
orig_magics = _magics | _async_method_magics
these_magics = orig_magics

if getattr(self, "_mock_methods", None) is not None:
these_magics = _magics.intersection(self._mock_methods)
these_magics = orig_magics.intersection(self._mock_methods)

remove_magics = set()
remove_magics = _magics - these_magics
remove_magics = orig_magics - these_magics

for entry in remove_magics:
if entry in type(self).__dict__:
Expand Down Expand Up @@ -2079,33 +2082,14 @@ def mock_add_spec(self, spec, spec_set=False):
self._mock_set_magics()


class AsyncMagicMixin:
class AsyncMagicMixin(MagicMixin):
def __init__(self, *args, **kw):
self._mock_set_async_magics() # make magic work for kwargs in init
self._mock_set_magics() # make magic work for kwargs in init
_safe_super(AsyncMagicMixin, self).__init__(*args, **kw)
self._mock_set_async_magics() # fix magic broken by upper level init

def _mock_set_async_magics(self):
these_magics = _async_magics

if getattr(self, "_mock_methods", None) is not None:
these_magics = _async_magics.intersection(self._mock_methods)
remove_magics = _async_magics - these_magics

for entry in remove_magics:
if entry in type(self).__dict__:
# remove unneeded magic methods
delattr(self, entry)

# don't overwrite existing attributes if called a second time
these_magics = these_magics - set(type(self).__dict__)

_type = type(self)
for entry in these_magics:
setattr(_type, entry, MagicProxy(entry, self))
self._mock_set_magics() # fix magic broken by upper level init


class MagicMock(MagicMixin, AsyncMagicMixin, Mock):
class MagicMock(MagicMixin, Mock):
"""
MagicMock is a subclass of Mock with default implementations
of most of the magic methods. You can use MagicMock without having to
Expand All @@ -2127,7 +2111,7 @@ def mock_add_spec(self, spec, spec_set=False):



class MagicProxy(object):
class MagicProxy(Base):
def __init__(self, name, parent):
self.name = name
self.parent = parent
Expand Down
57 changes: 38 additions & 19 deletions mock/tests/testasync.py
Expand Up @@ -393,6 +393,43 @@ def test_add_side_effect_iterable(self):
RuntimeError('coroutine raised StopIteration')
)

class AsyncMagicMethods(unittest.TestCase):
def test_async_magic_methods_return_async_mocks(self):
m_mock = MagicMock()
self.assertIsInstance(m_mock.__aenter__, AsyncMock)
self.assertIsInstance(m_mock.__aexit__, AsyncMock)
self.assertIsInstance(m_mock.__anext__, AsyncMock)
# __aiter__ is actually a synchronous object
# so should return a MagicMock
self.assertIsInstance(m_mock.__aiter__, MagicMock)

def test_sync_magic_methods_return_magic_mocks(self):
a_mock = AsyncMock()
self.assertIsInstance(a_mock.__enter__, MagicMock)
self.assertIsInstance(a_mock.__exit__, MagicMock)
self.assertIsInstance(a_mock.__next__, MagicMock)
self.assertIsInstance(a_mock.__len__, MagicMock)

def test_magicmock_has_async_magic_methods(self):
m_mock = MagicMock()
self.assertTrue(hasattr(m_mock, "__aenter__"))
self.assertTrue(hasattr(m_mock, "__aexit__"))
self.assertTrue(hasattr(m_mock, "__anext__"))

def test_asyncmock_has_sync_magic_methods(self):
a_mock = AsyncMock()
self.assertTrue(hasattr(a_mock, "__enter__"))
self.assertTrue(hasattr(a_mock, "__exit__"))
self.assertTrue(hasattr(a_mock, "__next__"))
self.assertTrue(hasattr(a_mock, "__len__"))

def test_magic_methods_are_async_functions(self):
m_mock = MagicMock()
self.assertIsInstance(m_mock.__aenter__, AsyncMock)
self.assertIsInstance(m_mock.__aexit__, AsyncMock)
# AsyncMocks are also coroutine functions
self.assertTrue(asyncio.iscoroutinefunction(m_mock.__aenter__))
self.assertTrue(asyncio.iscoroutinefunction(m_mock.__aexit__))

class AsyncContextManagerTest(unittest.TestCase):

Expand Down Expand Up @@ -420,24 +457,6 @@ async def main(self):
val = await response.json()
return val

def test_async_magic_methods_are_async_mocks_with_magicmock(self):
cm_mock = MagicMock(self.WithAsyncContextManager())
self.assertIsInstance(cm_mock.__aenter__, AsyncMock)
self.assertIsInstance(cm_mock.__aexit__, AsyncMock)

def test_magicmock_has_async_magic_methods(self):
cm = MagicMock(name='magic_cm')
self.assertTrue(hasattr(cm, "__aenter__"))
self.assertTrue(hasattr(cm, "__aexit__"))

def test_magic_methods_are_async_functions(self):
cm = MagicMock(name='magic_cm')
self.assertIsInstance(cm.__aenter__, AsyncMock)
self.assertIsInstance(cm.__aexit__, AsyncMock)
# AsyncMocks are also coroutine functions
self.assertTrue(asyncio.iscoroutinefunction(cm.__aenter__))
self.assertTrue(asyncio.iscoroutinefunction(cm.__aexit__))

def test_set_return_value_of_aenter(self):
def inner_test(mock_type):
pc = self.ProductionCode()
Expand Down Expand Up @@ -909,7 +928,7 @@ def test_assert_has_awaits_not_matching_spec_error(self):
async def f(x=None): pass

self.mock = AsyncMock(spec=f)
asyncio.run(self._runnable_test(1))
run(self._runnable_test(1))

with self.assertRaisesRegex(
AssertionError,
Expand Down
3 changes: 0 additions & 3 deletions mock/tests/testmagicmethods.py
Expand Up @@ -271,9 +271,6 @@ def test_magic_mock_equality(self):
self.assertEqual(mock == mock, True)
self.assertEqual(mock != mock, False)


# This should be fixed with issue38163
@unittest.expectedFailure
def test_asyncmock_defaults(self):
mock = AsyncMock()
self.assertEqual(int(mock), 1)
Expand Down

0 comments on commit 7b643e5

Please sign in to comment.