Skip to content

Commit

Permalink
run() implementation for py3.6, where it's missing
Browse files Browse the repository at this point in the history
  • Loading branch information
cjw296 committed Jan 13, 2020
1 parent 78fcb04 commit 3804999
Showing 1 changed file with 61 additions and 48 deletions.
109 changes: 61 additions & 48 deletions mock/tests/testasync.py
@@ -1,3 +1,4 @@

import asyncio
import inspect
import unittest
Expand All @@ -6,6 +7,18 @@
from mock.mock import _AwaitEvent


try:
from asyncio import run
except ImportError:
def run(main):
loop = asyncio.new_event_loop()
try:
return_value = loop.run_until_complete(main)
finally:
loop.close()
return return_value


def tearDownModule():
asyncio.set_event_loop_policy(None)

Expand Down Expand Up @@ -48,13 +61,13 @@ def test_is_async_patch(self):
def test_async(mock_method):
m = mock_method()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
run(m)

@patch(f'{async_foo_name}.async_method')
def test_no_parent_attribute(mock_method):
m = mock_method()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
run(m)

test_async()
test_no_parent_attribute()
Expand All @@ -71,7 +84,7 @@ def test_async_def_patch(self):
async def test_async():
self.assertIsInstance(async_func, AsyncMock)

asyncio.run(test_async())
run(test_async())
self.assertTrue(inspect.iscoroutinefunction(async_func))


Expand All @@ -88,7 +101,7 @@ def test_async():
with patch.object(AsyncClass, 'async_method') as mock_method:
m = mock_method()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
run(m)

test_async()

Expand All @@ -105,7 +118,7 @@ async def test_async():
self.assertIsInstance(async_func, AsyncMock)
self.assertTrue(inspect.iscoroutinefunction(async_func))

asyncio.run(test_async())
run(test_async())


class AsyncMockTest(unittest.TestCase):
Expand All @@ -123,7 +136,7 @@ def test_isawaitable(self):
mock = AsyncMock()
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
run(m)
self.assertIn('assert_awaited', dir(mock))

def test_iscoroutinefunction_normal_function(self):
Expand Down Expand Up @@ -172,7 +185,7 @@ async def main():
self.assertIsInstance(spec.awaited, _AwaitEvent)
spec.assert_not_awaited()

asyncio.run(main())
run(main())

self.assertTrue(asyncio.iscoroutinefunction(spec))
self.assertTrue(asyncio.iscoroutine(awaitable))
Expand Down Expand Up @@ -217,7 +230,7 @@ async def test_async():
self.assertIsNone(mock_method.await_args)
self.assertEqual(mock_method.await_args_list, [])

asyncio.run(test_async())
run(test_async())


class AsyncSpecTest(unittest.TestCase):
Expand All @@ -226,42 +239,42 @@ def test_spec_as_async_positional_magicmock(self):
self.assertIsInstance(mock, MagicMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
run(m)

def test_spec_as_async_kw_magicmock(self):
mock = MagicMock(spec=async_func)
self.assertIsInstance(mock, MagicMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
run(m)

def test_spec_as_async_kw_AsyncMock(self):
mock = AsyncMock(spec=async_func)
self.assertIsInstance(mock, AsyncMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
run(m)

def test_spec_as_async_positional_AsyncMock(self):
mock = AsyncMock(async_func)
self.assertIsInstance(mock, AsyncMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
run(m)

def test_spec_as_normal_kw_AsyncMock(self):
mock = AsyncMock(spec=normal_func)
self.assertIsInstance(mock, AsyncMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
run(m)

def test_spec_as_normal_positional_AsyncMock(self):
mock = AsyncMock(normal_func)
self.assertIsInstance(mock, AsyncMock)
m = mock()
self.assertTrue(inspect.isawaitable(m))
asyncio.run(m)
run(m)

def test_spec_async_mock(self):
@patch.object(AsyncClass, 'async_method', spec=True)
Expand Down Expand Up @@ -328,7 +341,7 @@ async def addition(self, var):
return var + 1

mock = AsyncMock(addition, return_value=10)
output = asyncio.run(mock(5))
output = run(mock(5))

self.assertEqual(output, 10)

Expand All @@ -337,23 +350,23 @@ async def addition(var):
return var + 1
mock = AsyncMock(addition, side_effect=Exception('err'))
with self.assertRaises(Exception):
asyncio.run(mock(5))
run(mock(5))

def test_add_side_effect_function(self):
async def addition(var):
return var + 1
mock = AsyncMock(side_effect=addition)
result = asyncio.run(mock(5))
result = run(mock(5))
self.assertEqual(result, 6)

def test_add_side_effect_iterable(self):
vals = [1, 2, 3]
mock = AsyncMock(side_effect=vals)
for item in vals:
self.assertEqual(item, asyncio.run(mock()))
self.assertEqual(item, run(mock()))

with self.assertRaises(RuntimeError) as e:
asyncio.run(mock())
run(mock())
self.assertEqual(
e.exception,
RuntimeError('coroutine raised StopIteration')
Expand Down Expand Up @@ -389,7 +402,7 @@ async def use_context_manager():
called = True
return result

result = asyncio.run(use_context_manager())
result = run(use_context_manager())
self.assertFalse(instance.entered)
self.assertFalse(instance.exited)
self.assertTrue(called)
Expand All @@ -411,7 +424,7 @@ async def use_context_manager():
async with mock_instance as result:
return result

self.assertIs(asyncio.run(use_context_manager()), expected_result)
self.assertIs(run(use_context_manager()), expected_result)

def test_mock_customize_async_context_manager_with_coroutine(self):
enter_called = False
Expand All @@ -435,7 +448,7 @@ async def use_context_manager():
async with mock_instance:
pass

asyncio.run(use_context_manager())
run(use_context_manager())
self.assertTrue(enter_called)
self.assertTrue(exit_called)

Expand All @@ -447,7 +460,7 @@ async def raise_in(context_manager):
instance = self.WithAsyncContextManager()
mock_instance = MagicMock(instance)
with self.assertRaises(TypeError):
asyncio.run(raise_in(mock_instance))
run(raise_in(mock_instance))


class AsyncIteratorTest(unittest.TestCase):
Expand Down Expand Up @@ -477,11 +490,11 @@ def test_mock_aiter_and_anext(self):

iterator = instance.__aiter__()
if asyncio.iscoroutine(iterator):
iterator = asyncio.run(iterator)
iterator = run(iterator)

mock_iterator = mock_instance.__aiter__()
if asyncio.iscoroutine(mock_iterator):
mock_iterator = asyncio.run(mock_iterator)
mock_iterator = run(mock_iterator)

self.assertEqual(asyncio.iscoroutine(iterator.__aiter__),
asyncio.iscoroutine(mock_iterator.__aiter__))
Expand All @@ -499,17 +512,17 @@ async def iterate(iterator):
expected = ["FOO", "BAR", "BAZ"]
with self.subTest("iterate through default value"):
mock_instance = MagicMock(self.WithAsyncIterator())
self.assertEqual([], asyncio.run(iterate(mock_instance)))
self.assertEqual([], run(iterate(mock_instance)))

with self.subTest("iterate through set return_value"):
mock_instance = MagicMock(self.WithAsyncIterator())
mock_instance.__aiter__.return_value = expected[:]
self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
self.assertEqual(expected, run(iterate(mock_instance)))

with self.subTest("iterate through set return_value iterator"):
mock_instance = MagicMock(self.WithAsyncIterator())
mock_instance.__aiter__.return_value = iter(expected[:])
self.assertEqual(expected, asyncio.run(iterate(mock_instance)))
self.assertEqual(expected, run(iterate(mock_instance)))


class AsyncMockAssert(unittest.TestCase):
Expand All @@ -526,56 +539,56 @@ def test_assert_awaited(self):
with self.assertRaises(AssertionError):
self.mock.assert_awaited()

asyncio.run(self._runnable_test())
run(self._runnable_test())
self.mock.assert_awaited()

def test_assert_awaited_once(self):
with self.assertRaises(AssertionError):
self.mock.assert_awaited_once()

asyncio.run(self._runnable_test())
run(self._runnable_test())
self.mock.assert_awaited_once()

asyncio.run(self._runnable_test())
run(self._runnable_test())
with self.assertRaises(AssertionError):
self.mock.assert_awaited_once()

def test_assert_awaited_with(self):
asyncio.run(self._runnable_test())
run(self._runnable_test())
msg = 'expected await not found'
with self.assertRaisesRegex(AssertionError, msg):
self.mock.assert_awaited_with('foo')

asyncio.run(self._runnable_test('foo'))
run(self._runnable_test('foo'))
self.mock.assert_awaited_with('foo')

asyncio.run(self._runnable_test('SomethingElse'))
run(self._runnable_test('SomethingElse'))
with self.assertRaises(AssertionError):
self.mock.assert_awaited_with('foo')

def test_assert_awaited_once_with(self):
with self.assertRaises(AssertionError):
self.mock.assert_awaited_once_with('foo')

asyncio.run(self._runnable_test('foo'))
run(self._runnable_test('foo'))
self.mock.assert_awaited_once_with('foo')

asyncio.run(self._runnable_test('foo'))
run(self._runnable_test('foo'))
with self.assertRaises(AssertionError):
self.mock.assert_awaited_once_with('foo')

def test_assert_any_wait(self):
with self.assertRaises(AssertionError):
self.mock.assert_any_await('NormalFoo')

asyncio.run(self._runnable_test('foo'))
run(self._runnable_test('foo'))
with self.assertRaises(AssertionError):
self.mock.assert_any_await('NormalFoo')

asyncio.run(self._runnable_test('NormalFoo'))
run(self._runnable_test('NormalFoo'))
self.mock.assert_any_await('NormalFoo')

asyncio.run(self._runnable_test('SomethingElse'))
run(self._runnable_test('SomethingElse'))
self.mock.assert_any_await('NormalFoo')

def test_assert_has_awaits_no_order(self):
Expand All @@ -585,42 +598,42 @@ def test_assert_has_awaits_no_order(self):
self.mock.assert_has_awaits(calls)
self.assertEqual(len(cm.exception.args), 1)

asyncio.run(self._runnable_test('foo'))
run(self._runnable_test('foo'))
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls)

asyncio.run(self._runnable_test('NormalFoo'))
run(self._runnable_test('NormalFoo'))
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls)

asyncio.run(self._runnable_test('baz'))
run(self._runnable_test('baz'))
self.mock.assert_has_awaits(calls)

asyncio.run(self._runnable_test('SomethingElse'))
run(self._runnable_test('SomethingElse'))
self.mock.assert_has_awaits(calls)

def test_assert_has_awaits_ordered(self):
calls = [call('NormalFoo'), call('baz')]
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True)

asyncio.run(self._runnable_test('baz'))
run(self._runnable_test('baz'))
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True)

asyncio.run(self._runnable_test('foo'))
run(self._runnable_test('foo'))
with self.assertRaises(AssertionError):
self.mock.assert_has_awaits(calls, any_order=True)

asyncio.run(self._runnable_test('NormalFoo'))
run(self._runnable_test('NormalFoo'))
self.mock.assert_has_awaits(calls, any_order=True)

asyncio.run(self._runnable_test('qux'))
run(self._runnable_test('qux'))
self.mock.assert_has_awaits(calls, any_order=True)

def test_assert_not_awaited(self):
self.mock.assert_not_awaited()

asyncio.run(self._runnable_test())
run(self._runnable_test())
with self.assertRaises(AssertionError):
self.mock.assert_not_awaited()

0 comments on commit 3804999

Please sign in to comment.