diff options
Diffstat (limited to 'lib/sqlalchemy/testing/plugin/pytestplugin.py')
| -rw-r--r-- | lib/sqlalchemy/testing/plugin/pytestplugin.py | 55 | 
1 files changed, 55 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 015598952..3df239afa 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -26,6 +26,11 @@ else:          from typing import Sequence  try: +    import asyncio +except ImportError: +    pass + +try:      import xdist  # noqa      has_xdist = True @@ -101,6 +106,24 @@ def pytest_configure(config):      plugin_base.set_fixture_functions(PytestFixtureFunctions) +    if config.option.dump_pyannotate: +        global DUMP_PYANNOTATE +        DUMP_PYANNOTATE = True + + +DUMP_PYANNOTATE = False + + +@pytest.fixture(autouse=True) +def collect_types_fixture(): +    if DUMP_PYANNOTATE: +        from pyannotate_runtime import collect_types + +        collect_types.start() +    yield +    if DUMP_PYANNOTATE: +        collect_types.stop() +  def pytest_sessionstart(session):      plugin_base.post_begin() @@ -109,6 +132,31 @@ def pytest_sessionstart(session):  def pytest_sessionfinish(session):      plugin_base.final_process_cleanup() +    if session.config.option.dump_pyannotate: +        from pyannotate_runtime import collect_types + +        collect_types.dump_stats(session.config.option.dump_pyannotate) + + +def pytest_collection_finish(session): +    if session.config.option.dump_pyannotate: +        from pyannotate_runtime import collect_types + +        lib_sqlalchemy = os.path.abspath("lib/sqlalchemy") + +        def _filter(filename): +            filename = os.path.normpath(os.path.abspath(filename)) +            if "lib/sqlalchemy" not in os.path.commonpath( +                [filename, lib_sqlalchemy] +            ): +                return None +            if "testing" in filename: +                return None + +            return filename + +        collect_types.init_types_collection(filter_filename=_filter) +  if has_xdist:      import uuid @@ -518,3 +566,10 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):      def get_current_test_name(self):          return os.environ.get("PYTEST_CURRENT_TEST") + +    def async_test(self, fn): +        @_pytest_fn_decorator +        def decorate(fn, *args, **kwargs): +            asyncio.get_event_loop().run_until_complete(fn(*args, **kwargs)) + +        return decorate(fn)  | 
