mirror of
https://github.com/torproject/stem.git
synced 2024-12-04 00:00:46 +00:00
Synchronous class mockability
The meta-programming behind our Synchronous class doesn't play well with test mocks. Handling this, and testing the permutations I can think of.
This commit is contained in:
parent
9f71ce9b21
commit
ef1e41ebce
@ -10,9 +10,11 @@ import datetime
|
||||
import functools
|
||||
import inspect
|
||||
import threading
|
||||
import unittest.mock
|
||||
|
||||
from concurrent.futures import Future
|
||||
|
||||
from typing import Any, AsyncIterator, Callable, Iterator, Optional, Type, Union
|
||||
from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
|
||||
|
||||
__all__ = [
|
||||
'conf',
|
||||
@ -213,19 +215,11 @@ class Synchronous(object):
|
||||
|
||||
# call any coroutines through this loop
|
||||
|
||||
def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
|
||||
if Synchronous.is_asyncio_context():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
with self._loop_thread_lock:
|
||||
if not self._loop_thread.is_alive():
|
||||
raise RuntimeError('%s has been stopped' % type(self).__name__)
|
||||
|
||||
return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
|
||||
|
||||
for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
|
||||
if inspect.iscoroutinefunction(func):
|
||||
setattr(self, method_name, functools.partial(call_async, func))
|
||||
for name, func in inspect.getmembers(self):
|
||||
if isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
|
||||
setattr(self, name, functools.partial(self._call_async_method, name))
|
||||
elif inspect.ismethod(func) and inspect.iscoroutinefunction(func):
|
||||
setattr(self, name, functools.partial(self._call_async_method, name))
|
||||
|
||||
asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
|
||||
|
||||
@ -312,6 +306,33 @@ class Synchronous(object):
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
def _call_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Run this async method from either a synchronous or asynchronous context.
|
||||
|
||||
:param method_name: name of the method to invoke
|
||||
:param args: positional arguments
|
||||
:param kwargs: keyword arguments
|
||||
|
||||
:returns: method's return value
|
||||
|
||||
:raises: **AttributeError** if this method doesn't exist
|
||||
"""
|
||||
|
||||
# Retrieving methods by name (rather than keeping a reference) so runtime
|
||||
# replacements like test mocks work.
|
||||
|
||||
func = getattr(type(self), method_name)
|
||||
|
||||
if Synchronous.is_asyncio_context():
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
with self._loop_thread_lock:
|
||||
if self._loop_thread and not self._loop_thread.is_alive():
|
||||
raise RuntimeError('%s has been closed' % type(self).__name__)
|
||||
|
||||
return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
async def convert_generator(generator: AsyncIterator) -> Iterator:
|
||||
return iter([d async for d in generator])
|
||||
|
@ -6,9 +6,10 @@ import asyncio
|
||||
import io
|
||||
import unittest
|
||||
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
from stem.util import Synchronous
|
||||
from stem.util.test_tools import coro_func_returning_value
|
||||
|
||||
EXAMPLE_OUTPUT = """\
|
||||
hello from a synchronous context
|
||||
@ -20,6 +21,8 @@ class Example(Synchronous):
|
||||
async def hello(self):
|
||||
return 'hello'
|
||||
|
||||
def sync_hello(self):
|
||||
return 'hello'
|
||||
|
||||
class TestSynchronous(unittest.TestCase):
|
||||
@patch('sys.stdout', new_callable = io.StringIO)
|
||||
@ -45,7 +48,7 @@ class TestSynchronous(unittest.TestCase):
|
||||
|
||||
def test_ainit(self):
|
||||
"""
|
||||
Check that our constructor runs __ainit__ if present.
|
||||
Check that our constructor runs __ainit__ when present.
|
||||
"""
|
||||
|
||||
class AinitDemo(Synchronous):
|
||||
@ -96,3 +99,51 @@ class TestSynchronous(unittest.TestCase):
|
||||
instance.start()
|
||||
self.assertEqual('hello', instance.hello())
|
||||
instance.stop()
|
||||
|
||||
def test_asynchronous_mockability(self):
|
||||
"""
|
||||
Check that method mocks are respected.
|
||||
"""
|
||||
|
||||
# mock prior to construction
|
||||
|
||||
with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
|
||||
instance = Example()
|
||||
self.assertEqual('mocked hello', instance.hello())
|
||||
|
||||
self.assertEqual('hello', instance.hello()) # mock should now be reverted
|
||||
instance.stop()
|
||||
|
||||
# mock after construction
|
||||
|
||||
instance = Example()
|
||||
|
||||
with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
|
||||
self.assertEqual('mocked hello', instance.hello())
|
||||
|
||||
self.assertEqual('hello', instance.hello())
|
||||
instance.stop()
|
||||
|
||||
def test_synchronous_mockability(self):
|
||||
"""
|
||||
Ensure we do not disrupt non-asynchronous method mocks.
|
||||
"""
|
||||
|
||||
# mock prior to construction
|
||||
|
||||
with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
|
||||
instance = Example()
|
||||
self.assertEqual('mocked hello', instance.sync_hello())
|
||||
|
||||
self.assertEqual('hello', instance.sync_hello()) # mock should now be reverted
|
||||
instance.stop()
|
||||
|
||||
# mock after construction
|
||||
|
||||
instance = Example()
|
||||
|
||||
with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
|
||||
self.assertEqual('mocked hello', instance.sync_hello())
|
||||
|
||||
self.assertEqual('hello', instance.sync_hello())
|
||||
instance.stop()
|
||||
|
Loading…
Reference in New Issue
Block a user