Synchronous class mockability

The meta-programming behind our Synchronous class doesn't play well with test
mocks. Handling this, and testing the permutations I can think of.
This commit is contained in:
Damian Johnson 2020-07-08 17:12:39 -07:00
parent 9f71ce9b21
commit ef1e41ebce
2 changed files with 88 additions and 16 deletions

View File

@ -10,9 +10,11 @@ import datetime
import functools
import inspect
import threading
import unittest.mock
from concurrent.futures import Future
from typing import Any, AsyncIterator, Callable, Iterator, Optional, Type, Union
from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
__all__ = [
'conf',
@ -213,19 +215,11 @@ class Synchronous(object):
# call any coroutines through this loop
def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
if Synchronous.is_asyncio_context():
return func(*args, **kwargs)
with self._loop_thread_lock:
if not self._loop_thread.is_alive():
raise RuntimeError('%s has been stopped' % type(self).__name__)
return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
if inspect.iscoroutinefunction(func):
setattr(self, method_name, functools.partial(call_async, func))
for name, func in inspect.getmembers(self):
if isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
setattr(self, name, functools.partial(self._call_async_method, name))
elif inspect.ismethod(func) and inspect.iscoroutinefunction(func):
setattr(self, name, functools.partial(self._call_async_method, name))
asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
@ -312,6 +306,33 @@ class Synchronous(object):
except RuntimeError:
return False
def _call_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
"""
Run this async method from either a synchronous or asynchronous context.
:param method_name: name of the method to invoke
:param args: positional arguments
:param kwargs: keyword arguments
:returns: method's return value
:raises: **AttributeError** if this method doesn't exist
"""
# Retrieving methods by name (rather than keeping a reference) so runtime
# replacements like test mocks work.
func = getattr(type(self), method_name)
if Synchronous.is_asyncio_context():
return func(self, *args, **kwargs)
with self._loop_thread_lock:
if self._loop_thread and not self._loop_thread.is_alive():
raise RuntimeError('%s has been closed' % type(self).__name__)
return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
def __iter__(self) -> Iterator:
async def convert_generator(generator: AsyncIterator) -> Iterator:
return iter([d async for d in generator])

View File

@ -6,9 +6,10 @@ import asyncio
import io
import unittest
from unittest.mock import patch
from unittest.mock import patch, Mock
from stem.util import Synchronous
from stem.util.test_tools import coro_func_returning_value
EXAMPLE_OUTPUT = """\
hello from a synchronous context
@ -20,6 +21,8 @@ class Example(Synchronous):
async def hello(self):
return 'hello'
def sync_hello(self):
return 'hello'
class TestSynchronous(unittest.TestCase):
@patch('sys.stdout', new_callable = io.StringIO)
@ -45,7 +48,7 @@ class TestSynchronous(unittest.TestCase):
def test_ainit(self):
"""
Check that our constructor runs __ainit__ if present.
Check that our constructor runs __ainit__ when present.
"""
class AinitDemo(Synchronous):
@ -96,3 +99,51 @@ class TestSynchronous(unittest.TestCase):
instance.start()
self.assertEqual('hello', instance.hello())
instance.stop()
def test_asynchronous_mockability(self):
"""
Check that method mocks are respected.
"""
# mock prior to construction
with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
instance = Example()
self.assertEqual('mocked hello', instance.hello())
self.assertEqual('hello', instance.hello()) # mock should now be reverted
instance.stop()
# mock after construction
instance = Example()
with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
self.assertEqual('mocked hello', instance.hello())
self.assertEqual('hello', instance.hello())
instance.stop()
def test_synchronous_mockability(self):
"""
Ensure we do not disrupt non-asynchronous method mocks.
"""
# mock prior to construction
with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
instance = Example()
self.assertEqual('mocked hello', instance.sync_hello())
self.assertEqual('hello', instance.sync_hello()) # mock should now be reverted
instance.stop()
# mock after construction
instance = Example()
with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
self.assertEqual('mocked hello', instance.sync_hello())
self.assertEqual('hello', instance.sync_hello())
instance.stop()