mirror of
https://github.com/torproject/stem.git
synced 2024-12-05 00:46:41 +00:00
Synchronous class mockability
The meta-programming behind our Synchronous class doesn't play well with test mocks. Handling this, and testing the permutations I can think of.
This commit is contained in:
parent
9f71ce9b21
commit
ef1e41ebce
@ -10,9 +10,11 @@ import datetime
|
|||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import threading
|
import threading
|
||||||
|
import unittest.mock
|
||||||
|
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
|
|
||||||
from typing import Any, AsyncIterator, Callable, Iterator, Optional, Type, Union
|
from typing import Any, AsyncIterator, Iterator, Optional, Type, Union
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'conf',
|
'conf',
|
||||||
@ -213,19 +215,11 @@ class Synchronous(object):
|
|||||||
|
|
||||||
# call any coroutines through this loop
|
# call any coroutines through this loop
|
||||||
|
|
||||||
def call_async(func: Callable, *args: Any, **kwargs: Any) -> Any:
|
for name, func in inspect.getmembers(self):
|
||||||
if Synchronous.is_asyncio_context():
|
if isinstance(func, unittest.mock.Mock) and inspect.iscoroutinefunction(func.side_effect):
|
||||||
return func(*args, **kwargs)
|
setattr(self, name, functools.partial(self._call_async_method, name))
|
||||||
|
elif inspect.ismethod(func) and inspect.iscoroutinefunction(func):
|
||||||
with self._loop_thread_lock:
|
setattr(self, name, functools.partial(self._call_async_method, name))
|
||||||
if not self._loop_thread.is_alive():
|
|
||||||
raise RuntimeError('%s has been stopped' % type(self).__name__)
|
|
||||||
|
|
||||||
return asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self._loop).result()
|
|
||||||
|
|
||||||
for method_name, func in inspect.getmembers(self, predicate = inspect.ismethod):
|
|
||||||
if inspect.iscoroutinefunction(func):
|
|
||||||
setattr(self, method_name, functools.partial(call_async, func))
|
|
||||||
|
|
||||||
asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
|
asyncio.run_coroutine_threadsafe(asyncio.coroutine(self.__ainit__)(), self._loop).result()
|
||||||
|
|
||||||
@ -312,6 +306,33 @@ class Synchronous(object):
|
|||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _call_async_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
"""
|
||||||
|
Run this async method from either a synchronous or asynchronous context.
|
||||||
|
|
||||||
|
:param method_name: name of the method to invoke
|
||||||
|
:param args: positional arguments
|
||||||
|
:param kwargs: keyword arguments
|
||||||
|
|
||||||
|
:returns: method's return value
|
||||||
|
|
||||||
|
:raises: **AttributeError** if this method doesn't exist
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Retrieving methods by name (rather than keeping a reference) so runtime
|
||||||
|
# replacements like test mocks work.
|
||||||
|
|
||||||
|
func = getattr(type(self), method_name)
|
||||||
|
|
||||||
|
if Synchronous.is_asyncio_context():
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
with self._loop_thread_lock:
|
||||||
|
if self._loop_thread and not self._loop_thread.is_alive():
|
||||||
|
raise RuntimeError('%s has been closed' % type(self).__name__)
|
||||||
|
|
||||||
|
return asyncio.run_coroutine_threadsafe(func(self, *args, **kwargs), self._loop).result()
|
||||||
|
|
||||||
def __iter__(self) -> Iterator:
|
def __iter__(self) -> Iterator:
|
||||||
async def convert_generator(generator: AsyncIterator) -> Iterator:
|
async def convert_generator(generator: AsyncIterator) -> Iterator:
|
||||||
return iter([d async for d in generator])
|
return iter([d async for d in generator])
|
||||||
|
@ -6,9 +6,10 @@ import asyncio
|
|||||||
import io
|
import io
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch, Mock
|
||||||
|
|
||||||
from stem.util import Synchronous
|
from stem.util import Synchronous
|
||||||
|
from stem.util.test_tools import coro_func_returning_value
|
||||||
|
|
||||||
EXAMPLE_OUTPUT = """\
|
EXAMPLE_OUTPUT = """\
|
||||||
hello from a synchronous context
|
hello from a synchronous context
|
||||||
@ -20,6 +21,8 @@ class Example(Synchronous):
|
|||||||
async def hello(self):
|
async def hello(self):
|
||||||
return 'hello'
|
return 'hello'
|
||||||
|
|
||||||
|
def sync_hello(self):
|
||||||
|
return 'hello'
|
||||||
|
|
||||||
class TestSynchronous(unittest.TestCase):
|
class TestSynchronous(unittest.TestCase):
|
||||||
@patch('sys.stdout', new_callable = io.StringIO)
|
@patch('sys.stdout', new_callable = io.StringIO)
|
||||||
@ -45,7 +48,7 @@ class TestSynchronous(unittest.TestCase):
|
|||||||
|
|
||||||
def test_ainit(self):
|
def test_ainit(self):
|
||||||
"""
|
"""
|
||||||
Check that our constructor runs __ainit__ if present.
|
Check that our constructor runs __ainit__ when present.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class AinitDemo(Synchronous):
|
class AinitDemo(Synchronous):
|
||||||
@ -96,3 +99,51 @@ class TestSynchronous(unittest.TestCase):
|
|||||||
instance.start()
|
instance.start()
|
||||||
self.assertEqual('hello', instance.hello())
|
self.assertEqual('hello', instance.hello())
|
||||||
instance.stop()
|
instance.stop()
|
||||||
|
|
||||||
|
def test_asynchronous_mockability(self):
|
||||||
|
"""
|
||||||
|
Check that method mocks are respected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# mock prior to construction
|
||||||
|
|
||||||
|
with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
|
||||||
|
instance = Example()
|
||||||
|
self.assertEqual('mocked hello', instance.hello())
|
||||||
|
|
||||||
|
self.assertEqual('hello', instance.hello()) # mock should now be reverted
|
||||||
|
instance.stop()
|
||||||
|
|
||||||
|
# mock after construction
|
||||||
|
|
||||||
|
instance = Example()
|
||||||
|
|
||||||
|
with patch('test.unit.util.synchronous.Example.hello', Mock(side_effect = coro_func_returning_value('mocked hello'))):
|
||||||
|
self.assertEqual('mocked hello', instance.hello())
|
||||||
|
|
||||||
|
self.assertEqual('hello', instance.hello())
|
||||||
|
instance.stop()
|
||||||
|
|
||||||
|
def test_synchronous_mockability(self):
|
||||||
|
"""
|
||||||
|
Ensure we do not disrupt non-asynchronous method mocks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# mock prior to construction
|
||||||
|
|
||||||
|
with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
|
||||||
|
instance = Example()
|
||||||
|
self.assertEqual('mocked hello', instance.sync_hello())
|
||||||
|
|
||||||
|
self.assertEqual('hello', instance.sync_hello()) # mock should now be reverted
|
||||||
|
instance.stop()
|
||||||
|
|
||||||
|
# mock after construction
|
||||||
|
|
||||||
|
instance = Example()
|
||||||
|
|
||||||
|
with patch('test.unit.util.synchronous.Example.sync_hello', Mock(return_value = 'mocked hello')):
|
||||||
|
self.assertEqual('mocked hello', instance.sync_hello())
|
||||||
|
|
||||||
|
self.assertEqual('hello', instance.sync_hello())
|
||||||
|
instance.stop()
|
||||||
|
Loading…
Reference in New Issue
Block a user