diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/ext/asyncio/engine.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/asyncio.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/engines.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/_concurrency_py3k.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/concurrency.py | 6 |
7 files changed, 46 insertions, 18 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 16edcc2b2..93adaf78a 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -41,7 +41,7 @@ def create_async_engine(*arg, **kw): class AsyncConnectable: - __slots__ = "_slots_dispatch" + __slots__ = "_slots_dispatch", "__weakref__" @util.create_proxy_methods( diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py index 52386d33e..bdf730a4c 100644 --- a/lib/sqlalchemy/testing/asyncio.py +++ b/lib/sqlalchemy/testing/asyncio.py @@ -22,12 +22,17 @@ import inspect from . import config from ..util.concurrency import _util_async_run +from ..util.concurrency import _util_async_run_coroutine_function # may be set to False if the # --disable-asyncio flag is passed to the test runner. ENABLE_ASYNCIO = True +def _run_coroutine_function(fn, *args, **kwargs): + return _util_async_run_coroutine_function(fn, *args, **kwargs) + + def _assume_async(fn, *args, **kwargs): """Run a function in an asyncio loop unconditionally. diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index d0a1bc0d0..4d4563afb 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -97,7 +97,10 @@ class ConnectionKiller(object): self.conns = set() for rec in list(self.testing_engines): - rec.dispose() + if hasattr(rec, "sync_engine"): + rec.sync_engine.dispose() + else: + rec.dispose() def assert_all_closed(self): for rec in self.proxy_refs: @@ -236,10 +239,12 @@ def reconnecting_engine(url=None, options=None): return engine -def testing_engine(url=None, options=None, future=False): +def testing_engine(url=None, options=None, future=False, asyncio=False): """Produce an engine configured by --options with optional overrides.""" - if future or config.db and config.db._is_future: + if asyncio: + from sqlalchemy.ext.asyncio import create_async_engine as create_engine + elif future or config.db and config.db._is_future: from sqlalchemy.future import create_engine else: from sqlalchemy import create_engine @@ -263,7 +268,10 @@ def testing_engine(url=None, options=None, future=False): default_opt.update(options) engine = create_engine(url, **options) - engine._has_events = True # enable event blocks, helps with profiling + if asyncio: + engine.sync_engine._has_events = True + else: + engine._has_events = True # enable event blocks, helps with profiling if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index a52fdd196..0ede25176 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -48,11 +48,6 @@ class TestBase(object): # skipped. __skip_if__ = None - # If this class should be wrapped in asyncio compatibility functions - # when using an async engine. This should be set to False only for tests - # that use the asyncio features of sqlalchemy directly - __asyncio_wrap__ = True - def assert_(self, val, msg=None): assert val, msg @@ -95,12 +90,6 @@ class TestBase(object): # engines.drop_all_tables(metadata, config.db) -class AsyncTestBase(TestBase): - """Mixin marking a test as using its own explicit asyncio patterns.""" - - __asyncio_wrap__ = False - - class FutureEngineMixin(object): @classmethod def setup_class(cls): diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 6be64aa61..46468a07d 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -255,7 +255,7 @@ def pytest_pycollect_makeitem(collector, name, obj): if inspect.isclass(obj) and plugin_base.want_class(name, obj): from sqlalchemy.testing import config - if config.any_async and getattr(obj, "__asyncio_wrap__", True): + if config.any_async: obj = _apply_maybe_async(obj) ctor = getattr(pytest.Class, "from_parent", pytest.Class) @@ -277,6 +277,13 @@ def pytest_pycollect_makeitem(collector, name, obj): return [] +def _is_wrapped_coroutine_function(fn): + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + + return inspect.iscoroutinefunction(fn) + + def _apply_maybe_async(obj, recurse=True): from sqlalchemy.testing import asyncio @@ -286,6 +293,7 @@ def _apply_maybe_async(obj, recurse=True): (callable(value) or isinstance(value, classmethod)) and not getattr(value, "_maybe_async_applied", False) and (name.startswith("test_") or name in setup_names) + and not _is_wrapped_coroutine_function(value) ): is_classmethod = False if isinstance(value, classmethod): @@ -656,6 +664,6 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): @_pytest_fn_decorator def decorate(fn, *args, **kwargs): - asyncio._assume_async(fn, *args, **kwargs) + asyncio._run_coroutine_function(fn, *args, **kwargs) return decorate(fn) diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index 6042e4395..663d3e0f4 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -136,6 +136,18 @@ class AsyncAdaptedLock: self.mutex.release() +def _util_async_run_coroutine_function(fn, *args, **kwargs): + """for test suite/ util only""" + + loop = asyncio.get_event_loop() + if loop.is_running(): + raise Exception( + "for async run coroutine we expect that no greenlet or event " + "loop is running when we start out" + ) + return loop.run_until_complete(fn(*args, **kwargs)) + + def _util_async_run(fn, *args, **kwargs): """for test suite/ util only""" diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index 7b4ff6ba4..c44efba62 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -14,6 +14,9 @@ if compat.py3k: from ._concurrency_py3k import greenlet_spawn from ._concurrency_py3k import AsyncAdaptedLock from ._concurrency_py3k import _util_async_run # noqa F401 + from ._concurrency_py3k import ( + _util_async_run_coroutine_function, + ) # noqa F401, E501 from ._concurrency_py3k import asyncio # noqa F401 if not have_greenlet: @@ -42,3 +45,6 @@ if not have_greenlet: def _util_async_run(fn, *arg, **kw): # noqa F81 return fn(*arg, **kw) + + def _util_async_run_coroutine_function(fn, *arg, **kw): # noqa F81 + _not_implemented() |
