Source code for mode.utils.mocks

"""Mocking and testing utilities."""
import asyncio
import builtins
import inspect
import sys
import types
import unittest.mock
from asyncio import coroutine
from contextlib import contextmanager
from itertools import count
from typing import Any, Callable, ContextManager, List, Optional, Type, Union

__all__ = [

# Module-level counter: Mock.__call__ draws the next value from it on every
# call so calls made on different mocks can be ordered relative to each other.
MOCK_CALL_COUNT = count(start=0)

class IN:
    """Equality helper that matches any one of several alternatives.

    An ``IN`` instance compares equal to every value contained in the
    alternatives it was constructed with, which reads naturally in asserts:

    .. sourcecode:: python

        assert foo.value == IN(a, b)
    """

    def __init__(self, *alternatives):
        # Stored as the tuple received from *alternatives.
        self.alternatives = alternatives

    def __eq__(self, other: Any) -> bool:
        # Tuple containment: identity first, then ``==`` per element.
        return other in self.alternatives

    def __ne__(self, other: Any) -> bool:
        return other not in self.alternatives

    def __repr__(self) -> str:
        rendered = ' | '.join(str(alt) for alt in self.alternatives)
        return f'<IN: {rendered}>'
class Mock(unittest.mock.Mock):
    """Mock object that also records a cross-mock call ordering.

    Each successful call takes the next value of the module-level
    ``MOCK_CALL_COUNT`` counter, stores it in :attr:`global_call_count`
    and appends it to :attr:`call_counts`.
    """

    # Counter value taken by the most recent call (None until first call).
    global_call_count: Optional[int] = None
    # Counter values of every call to this mock, in call order; the list is
    # created lazily on the first call.
    call_counts: List[int] = None

    def __call__(self, *args, **kwargs):
        result = super().__call__(*args, **kwargs)
        # Only reached when the underlying call returned normally.
        self.global_call_count = seq = next(MOCK_CALL_COUNT)
        if self.call_counts is not None:
            self.call_counts.append(seq)
        else:
            self.call_counts = [seq]
        return result

    def reset_mock(self, *args, **kwargs):
        """Reset the mock and empty the recorded :attr:`call_counts`."""
        super().reset_mock(*args, **kwargs)
        recorded = self.call_counts
        if recorded is not None:
            recorded.clear()
class _ContextMock(Mock, ContextManager):
    """Internal mock class that supports the :keyword:`with` statement.

    The :keyword:`with` statement looks ``__enter__`` and ``__exit__`` up
    on the class rather than on the instance, so this class provides
    concrete implementations of both.
    """

    def __enter__(self) -> '_ContextMock':
        # Entering the context hands back the mock itself.
        return self

    def __exit__(self,
                 exc_type: Type[BaseException] = None,
                 exc_val: BaseException = None,
                 exc_tb: types.TracebackType = None) -> Optional[bool]:
        # A None result is falsy, so exceptions raised in the body propagate.
        return None
def ContextMock(*args: Any, **kwargs: Any) -> _ContextMock:
    """Build a mock usable as the subject of a :keyword:`with` statement.

    ``__enter__`` and ``__exit__`` are attached as child mocks (via
    ``attach_mock``); entering the context returns the mock itself.
    """
    mock = _ContextMock(*args, **kwargs)
    for method_name in ('__enter__', '__exit__'):
        mock.attach_mock(_ContextMock(), method_name)
    mock.__enter__.return_value = mock
    # A truthy __exit__ result would suppress exceptions raised inside the
    # ``with`` body, so it has to stay None.
    mock.__exit__.return_value = None
    return mock
class AsyncMock(unittest.mock.Mock):
    """Mock for ``async def`` function/method or anything awaitable.

    The constructor's positional and keyword arguments (e.g.
    ``return_value`` or ``side_effect``) configure an inner :class:`Mock`
    attached as :attr:`coro`. Calling this mock returns an awaitable; the
    inner ``coro`` mock is only called once that awaitable is awaited, and
    its result becomes the result of the ``await``. If ``coro`` returns a
    generator or another awaitable, that is awaited in turn.
    """

    def __init__(self, *args: Any, name: str = None, **kwargs: Any) -> None:
        super().__init__(name=name)
        coro = Mock(*args, **kwargs)
        self.attach_mock(coro, 'coro')

        # Generator-based coroutine built with ``types.coroutine`` instead of
        # the ``asyncio.coroutine`` decorator (deprecated in 3.8, removed in
        # 3.11). It mirrors what that decorator did for a plain callable:
        # call lazily on await, then await generator/awaitable results.
        @types.coroutine
        def _awaitable_call(*call_args: Any, **call_kwargs: Any) -> Any:
            result = coro(*call_args, **call_kwargs)
            if inspect.isgenerator(result):
                result = yield from result
            elif inspect.isawaitable(result):
                result = yield from result.__await__()
            return result

        self.side_effect = _awaitable_call
class AsyncMagicMock(unittest.mock.MagicMock):
    """A magic mock type for ``async def`` functions/methods.

    The constructor's positional and keyword arguments configure an inner
    ``MagicMock`` attached as :attr:`coro`. Calling this mock returns an
    awaitable; ``coro`` is only called once that awaitable is awaited, and
    its result becomes the result of the ``await``. If ``coro`` returns a
    generator or another awaitable, that is awaited in turn.
    """

    def __init__(self, *args: Any, name: str = None, **kwargs: Any) -> None:
        super().__init__(name=name)
        coro = MagicMock(*args, **kwargs)
        self.attach_mock(coro, 'coro')

        # Generator-based coroutine built with ``types.coroutine`` instead of
        # the ``asyncio.coroutine`` decorator (deprecated in 3.8, removed in
        # 3.11). It mirrors what that decorator did for a plain callable:
        # call lazily on await, then await generator/awaitable results.
        @types.coroutine
        def _awaitable_call(*call_args: Any, **call_kwargs: Any) -> Any:
            result = coro(*call_args, **call_kwargs)
            if inspect.isgenerator(result):
                result = yield from result
            elif inspect.isawaitable(result):
                result = yield from result.__await__()
            return result

        self.side_effect = _awaitable_call
class AsyncContextMock(unittest.mock.Mock):
    """Mock for :class:`typing.AsyncContextManager`.

    You can use this to mock asynchronous context managers, when an object
    with a fully defined ``__aenter__`` and ``__aexit__`` is required.

    ``async with mock as value`` binds ``value`` as follows:

    * if ``side_effect`` is an exception class or instance, it is raised;
    * otherwise, if ``side_effect`` is set, it is called with no arguments
      and its result is used;
    * otherwise ``aenter_return`` is used unless it is :const:`None`, in
      which case the mock's ``return_value`` is used;
    * if the chosen value is an :class:`AsyncMock`, its ``coro`` mock is
      returned instead.

    ``__aexit__`` always returns ``aexit_return`` (a truthy value would
    suppress exceptions raised inside the ``async with`` body).

    Here's an example mocking an :pypi:`aiohttp` client session
    (``'myapp.http.session'`` is an illustrative patch target; use the
    dotted path your code actually looks up):

    .. code-block:: python

        import http

        import pytest

        from mode.utils.mocks import AsyncContextMock, AsyncMock, Mock

        @pytest.fixture()
        def session(monkeypatch):
            session = Mock(
                name='http_client',
                get=Mock(
                    return_value=AsyncContextMock(
                        return_value=Mock(
                            status=http.HTTPStatus.OK,
                            json=AsyncMock(
                                return_value={'hello': 'json'},
                            ),
                        ),
                    ),
                ),
            )
            monkeypatch.setattr('myapp.http.session', session)
            return session

        @pytest.mark.asyncio
        async def test_session(session):
            async with session.get('https://example.com/') as response:
                assert response.status == http.HTTPStatus.OK
                assert await response.json() == {'hello': 'json'}
    """

    def __init__(self, *args: Any,
                 aenter_return: Any = None,
                 aexit_return: Any = None,
                 side_effect: Union[Callable, BaseException,
                                    Type[BaseException]] = None,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.aenter_return = aenter_return
        self.aexit_return = aexit_return
        self.side_effect = side_effect

    async def __aenter__(self) -> Any:
        effect = self.side_effect
        if effect:
            # Like unittest.mock: exception classes and instances are
            # raised; anything else is called.
            if isinstance(effect, BaseException) or (
                    isinstance(effect, type)
                    and issubclass(effect, BaseException)):
                raise effect
            return effect()
        # Compare against None so falsy values (0, '', [], ...) can be
        # produced by ``__aenter__`` instead of falling back.
        if self.aenter_return is not None:
            mgr = self.aenter_return
        else:
            mgr = self.return_value
        if isinstance(mgr, AsyncMock):
            return mgr.coro
        return mgr

    async def __aexit__(self, *args: Any) -> Any:
        return self.aexit_return


AsyncContextManagerMock = AsyncContextMock  # XXX compat alias
class FutureMock(unittest.mock.Mock):
    """Mock standing in for an :class:`asyncio.Future`.

    Awaiting the mock sets :attr:`awaited` to True and calls the mock; the
    call result is yielded from ``__await__``.
    """

    # Set to True by ``__await__``; False until the mock is awaited.
    awaited = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Remember the event loop current at construction time.
        self._loop = asyncio.get_event_loop()

    def __await__(self):
        self.awaited = True
        outcome = self()
        # NOTE(review): the call result is *yielded* to whatever drives the
        # awaitable (e.g. an asyncio Task) rather than returned, so the await
        # expression itself evaluates to None if resumed — confirm intended.
        yield outcome

    def assert_awaited(self):
        """Fail unless the mock has been awaited."""
        assert self.awaited

    def assert_not_awaited(self):
        """Fail if the mock has been awaited."""
        assert not self.awaited
@contextmanager
def patch_module(*names: str, new_callable: Any = Mock):
    """Mock one or more modules such that every attribute is a :class:`Mock`.

    Each name in ``names`` is replaced in :data:`sys.modules` by a fresh
    module object. Accessing an attribute the fake module does not have
    calls ``new_callable()``, stores the result on the module and returns
    it, so repeated access yields the same object.

    Arguments:
        *names: Module names to replace in :data:`sys.modules`.
        new_callable: Factory used to create missing attributes.

    Yields:
        List[types.ModuleType]: the fake modules, one per entry in
        ``names``, in the same order.

    On exit each name is restored to whatever :data:`sys.modules` held
    before the context was entered, or removed if it held nothing.
    """
    # Sentinel distinguishing "not in sys.modules" from a stored None.
    absent = object()
    prev = {}

    class MockModule(types.ModuleType):

        def __getattr__(self, attr):
            setattr(self, attr, new_callable())
            return types.ModuleType.__getattribute__(self, attr)

    mods = []
    try:
        for name in names:
            # Only remember the state from before the first replacement:
            # for a name listed twice, the second lookup would otherwise
            # save our own fake module and "restore" it on exit.
            if name not in prev:
                prev[name] = sys.modules.get(name, absent)
            mod = sys.modules[name] = MockModule(name)
            mods.append(mod)
        yield mods
    finally:
        for name, original in prev.items():
            if original is absent:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = original
@contextmanager
def mask_module(*modnames):
    """Ban some modules from being importable inside the context.

    While the context is active, :func:`builtins.__import__` is replaced by
    a wrapper that raises :exc:`ImportError` when the requested name is
    exactly one of ``modnames``; every other import is delegated to the
    original ``__import__``, which is restored on exit.

    Matching is on the exact name passed to ``__import__``, so masking
    ``'foo'`` blocks ``import foo`` and ``from foo import x`` but not
    ``import foo.bar``.

    For example::

        >>> with mask_module('sys'):
        ...     try:
        ...         import sys
        ...     except ImportError:
        ...         print('sys not found')
        sys not found

        >>> import sys  # noqa
        >>> sys.version_info >= (3,)
        True

    (Adapted from an external recipe; the original reference link is
    missing from this docstring.)
    """
    realimport = builtins.__import__

    def myimp(name, *args, **kwargs):
        if name in modnames:
            raise ImportError(f'No module named {name}')
        else:
            return realimport(name, *args, **kwargs)

    builtins.__import__ = myimp
    try:
        yield
    finally:
        # Always put the real importer back, even if the body raised.
        builtins.__import__ = realimport
# Aliases for unittest.mock helpers, importable from this module.
ANY = unittest.mock.ANY
MagicMock = unittest.mock.MagicMock
# ``call`` must be the call-matcher object (used as
# ``mock.assert_has_calls([call(...)])``), not an alias of ``patch``.
call = unittest.mock.call
patch = unittest.mock.patch