summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/plugin
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-01-06 01:14:26 -0500
committermike bayer <mike_mp@zzzcomputing.com>2019-01-06 17:34:50 +0000
commit1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch)
tree28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/testing/plugin
parent404e69426b05a82d905cbb3ad33adafccddb00dd (diff)
downloadsqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits applied at all. The black run will format code consistently, however in some cases that are prevalent in SQLAlchemy code it produces too-long lines. The too-long lines will be resolved in the following commit that will resolve all remaining flake8 issues including shadowed builtins, long lines, import order, unused imports, duplicate imports, and docstring issues. Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/testing/plugin')
-rw-r--r--lib/sqlalchemy/testing/plugin/bootstrap.py7
-rw-r--r--lib/sqlalchemy/testing/plugin/noseplugin.py23
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py359
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py104
4 files changed, 310 insertions, 183 deletions
diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py
index 497fcb7e5..bb52c125c 100644
--- a/lib/sqlalchemy/testing/plugin/bootstrap.py
+++ b/lib/sqlalchemy/testing/plugin/bootstrap.py
@@ -20,20 +20,23 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0.
import os
import sys
-bootstrap_file = locals()['bootstrap_file']
-to_bootstrap = locals()['to_bootstrap']
+bootstrap_file = locals()["bootstrap_file"]
+to_bootstrap = locals()["to_bootstrap"]
def load_file_as_module(name):
path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
if sys.version_info >= (3, 3):
from importlib import machinery
+
mod = machinery.SourceFileLoader(name, path).load_module()
else:
import imp
+
mod = imp.load_source(name, path)
return mod
+
if to_bootstrap == "pytest":
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py
index 20ea61d89..0c28a5213 100644
--- a/lib/sqlalchemy/testing/plugin/noseplugin.py
+++ b/lib/sqlalchemy/testing/plugin/noseplugin.py
@@ -25,6 +25,7 @@ import sys
from nose.plugins import Plugin
import nose
+
fixtures = None
py3k = sys.version_info >= (3, 0)
@@ -33,7 +34,7 @@ py3k = sys.version_info >= (3, 0)
class NoseSQLAlchemy(Plugin):
enabled = True
- name = 'sqla_testing'
+ name = "sqla_testing"
score = 100
def options(self, parser, env=os.environ):
@@ -41,10 +42,14 @@ class NoseSQLAlchemy(Plugin):
opt = parser.add_option
def make_option(name, **kw):
- callback_ = kw.pop("callback", None) or kw.pop("zeroarg_callback", None)
+ callback_ = kw.pop("callback", None) or kw.pop(
+ "zeroarg_callback", None
+ )
if callback_:
+
def wrap_(option, opt_str, value, parser):
callback_(opt_str, value, parser)
+
kw["callback"] = wrap_
opt(name, **kw)
@@ -73,7 +78,7 @@ class NoseSQLAlchemy(Plugin):
def wantMethod(self, fn):
if py3k:
- if not hasattr(fn.__self__, 'cls'):
+ if not hasattr(fn.__self__, "cls"):
return False
cls = fn.__self__.cls
else:
@@ -84,24 +89,24 @@ class NoseSQLAlchemy(Plugin):
return plugin_base.want_class(cls)
def beforeTest(self, test):
- if not hasattr(test.test, 'cls'):
+ if not hasattr(test.test, "cls"):
return
plugin_base.before_test(
test,
test.test.cls.__module__,
- test.test.cls, test.test.method.__name__)
+ test.test.cls,
+ test.test.method.__name__,
+ )
def afterTest(self, test):
plugin_base.after_test(test)
def startContext(self, ctx):
- if not isinstance(ctx, type) \
- or not issubclass(ctx, fixtures.TestBase):
+ if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
return
plugin_base.start_test_class(ctx)
def stopContext(self, ctx):
- if not isinstance(ctx, type) \
- or not issubclass(ctx, fixtures.TestBase):
+ if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
return
plugin_base.stop_test_class(ctx)
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 0ffcae093..5d6bf2975 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -46,58 +46,130 @@ options = None
def setup_options(make_option):
- make_option("--log-info", action="callback", type="string", callback=_log,
- help="turn on info logging for <LOG> (multiple OK)")
- make_option("--log-debug", action="callback",
- type="string", callback=_log,
- help="turn on debug logging for <LOG> (multiple OK)")
- make_option("--db", action="append", type="string", dest="db",
- help="Use prefab database uri. Multiple OK, "
- "first one is run by default.")
- make_option('--dbs', action='callback', zeroarg_callback=_list_dbs,
- help="List available prefab dbs")
- make_option("--dburi", action="append", type="string", dest="dburi",
- help="Database uri. Multiple OK, "
- "first one is run by default.")
- make_option("--dropfirst", action="store_true", dest="dropfirst",
- help="Drop all tables in the target database first")
- make_option("--backend-only", action="store_true", dest="backend_only",
- help="Run only tests marked with __backend__")
- make_option("--nomemory", action="store_true", dest="nomemory",
- help="Don't run memory profiling tests")
- make_option("--postgresql-templatedb", type="string",
- help="name of template database to use for PostgreSQL "
- "CREATE DATABASE (defaults to current database)")
- make_option("--low-connections", action="store_true",
- dest="low_connections",
- help="Use a low number of distinct connections - "
- "i.e. for Oracle TNS")
- make_option("--write-idents", type="string", dest="write_idents",
- help="write out generated follower idents to <file>, "
- "when -n<num> is used")
- make_option("--reversetop", action="store_true",
- dest="reversetop", default=False,
- help="Use a random-ordering set implementation in the ORM "
- "(helps reveal dependency issues)")
- make_option("--requirements", action="callback", type="string",
- callback=_requirements_opt,
- help="requirements class for testing, overrides setup.cfg")
- make_option("--with-cdecimal", action="store_true",
- dest="cdecimal", default=False,
- help="Monkeypatch the cdecimal library into Python 'decimal' "
- "for all tests")
- make_option("--include-tag", action="callback", callback=_include_tag,
- type="string",
- help="Include tests with tag <tag>")
- make_option("--exclude-tag", action="callback", callback=_exclude_tag,
- type="string",
- help="Exclude tests with tag <tag>")
- make_option("--write-profiles", action="store_true",
- dest="write_profiles", default=False,
- help="Write/update failing profiling data.")
- make_option("--force-write-profiles", action="store_true",
- dest="force_write_profiles", default=False,
- help="Unconditionally write/update profiling data.")
+ make_option(
+ "--log-info",
+ action="callback",
+ type="string",
+ callback=_log,
+ help="turn on info logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--log-debug",
+ action="callback",
+ type="string",
+ callback=_log,
+ help="turn on debug logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--db",
+ action="append",
+ type="string",
+ dest="db",
+ help="Use prefab database uri. Multiple OK, "
+ "first one is run by default.",
+ )
+ make_option(
+ "--dbs",
+ action="callback",
+ zeroarg_callback=_list_dbs,
+ help="List available prefab dbs",
+ )
+ make_option(
+ "--dburi",
+ action="append",
+ type="string",
+ dest="dburi",
+ help="Database uri. Multiple OK, " "first one is run by default.",
+ )
+ make_option(
+ "--dropfirst",
+ action="store_true",
+ dest="dropfirst",
+ help="Drop all tables in the target database first",
+ )
+ make_option(
+ "--backend-only",
+ action="store_true",
+ dest="backend_only",
+ help="Run only tests marked with __backend__",
+ )
+ make_option(
+ "--nomemory",
+ action="store_true",
+ dest="nomemory",
+ help="Don't run memory profiling tests",
+ )
+ make_option(
+ "--postgresql-templatedb",
+ type="string",
+ help="name of template database to use for PostgreSQL "
+ "CREATE DATABASE (defaults to current database)",
+ )
+ make_option(
+ "--low-connections",
+ action="store_true",
+ dest="low_connections",
+ help="Use a low number of distinct connections - "
+ "i.e. for Oracle TNS",
+ )
+ make_option(
+ "--write-idents",
+ type="string",
+ dest="write_idents",
+ help="write out generated follower idents to <file>, "
+ "when -n<num> is used",
+ )
+ make_option(
+ "--reversetop",
+ action="store_true",
+ dest="reversetop",
+ default=False,
+ help="Use a random-ordering set implementation in the ORM "
+ "(helps reveal dependency issues)",
+ )
+ make_option(
+ "--requirements",
+ action="callback",
+ type="string",
+ callback=_requirements_opt,
+ help="requirements class for testing, overrides setup.cfg",
+ )
+ make_option(
+ "--with-cdecimal",
+ action="store_true",
+ dest="cdecimal",
+ default=False,
+ help="Monkeypatch the cdecimal library into Python 'decimal' "
+ "for all tests",
+ )
+ make_option(
+ "--include-tag",
+ action="callback",
+ callback=_include_tag,
+ type="string",
+ help="Include tests with tag <tag>",
+ )
+ make_option(
+ "--exclude-tag",
+ action="callback",
+ callback=_exclude_tag,
+ type="string",
+ help="Exclude tests with tag <tag>",
+ )
+ make_option(
+ "--write-profiles",
+ action="store_true",
+ dest="write_profiles",
+ default=False,
+ help="Write/update failing profiling data.",
+ )
+ make_option(
+ "--force-write-profiles",
+ action="store_true",
+ dest="force_write_profiles",
+ default=False,
+ help="Unconditionally write/update profiling data.",
+ )
def configure_follower(follower_ident):
@@ -108,6 +180,7 @@ def configure_follower(follower_ident):
"""
from sqlalchemy.testing import provision
+
provision.FOLLOWER_IDENT = follower_ident
@@ -121,9 +194,9 @@ def memoize_important_follower_config(dict_):
callables, so we have to just copy all of that over.
"""
- dict_['memoized_config'] = {
- 'include_tags': include_tags,
- 'exclude_tags': exclude_tags
+ dict_["memoized_config"] = {
+ "include_tags": include_tags,
+ "exclude_tags": exclude_tags,
}
@@ -134,14 +207,14 @@ def restore_important_follower_config(dict_):
"""
global include_tags, exclude_tags
- include_tags.update(dict_['memoized_config']['include_tags'])
- exclude_tags.update(dict_['memoized_config']['exclude_tags'])
+ include_tags.update(dict_["memoized_config"]["include_tags"])
+ exclude_tags.update(dict_["memoized_config"]["exclude_tags"])
def read_config():
global file_config
file_config = configparser.ConfigParser()
- file_config.read(['setup.cfg', 'test.cfg'])
+ file_config.read(["setup.cfg", "test.cfg"])
def pre_begin(opt):
@@ -155,6 +228,7 @@ def pre_begin(opt):
def set_coverage_flag(value):
options.has_coverage = value
+
_skip_test_exception = None
@@ -171,34 +245,33 @@ def post_begin():
# late imports, has to happen after config as well
# as nose plugins like coverage
- global util, fixtures, engines, exclusions, \
- assertions, warnings, profiling,\
- config, testing
- from sqlalchemy import testing # noqa
+ global util, fixtures, engines, exclusions, assertions, warnings, profiling, config, testing
+ from sqlalchemy import testing # noqa
from sqlalchemy.testing import fixtures, engines, exclusions # noqa
- from sqlalchemy.testing import assertions, warnings, profiling # noqa
+ from sqlalchemy.testing import assertions, warnings, profiling # noqa
from sqlalchemy.testing import config # noqa
from sqlalchemy import util # noqa
- warnings.setup_filters()
+ warnings.setup_filters()
def _log(opt_str, value, parser):
global logging
if not logging:
import logging
+
logging.basicConfig()
- if opt_str.endswith('-info'):
+ if opt_str.endswith("-info"):
logging.getLogger(value).setLevel(logging.INFO)
- elif opt_str.endswith('-debug'):
+ elif opt_str.endswith("-debug"):
logging.getLogger(value).setLevel(logging.DEBUG)
def _list_dbs(*args):
print("Available --db options (use --dburi to override)")
- for macro in sorted(file_config.options('db')):
- print("%20s\t%s" % (macro, file_config.get('db', macro)))
+ for macro in sorted(file_config.options("db")):
+ print("%20s\t%s" % (macro, file_config.get("db", macro)))
sys.exit(0)
@@ -207,11 +280,12 @@ def _requirements_opt(opt_str, value, parser):
def _exclude_tag(opt_str, value, parser):
- exclude_tags.add(value.replace('-', '_'))
+ exclude_tags.add(value.replace("-", "_"))
def _include_tag(opt_str, value, parser):
- include_tags.add(value.replace('-', '_'))
+ include_tags.add(value.replace("-", "_"))
+
pre_configure = []
post_configure = []
@@ -243,7 +317,8 @@ def _set_nomemory(opt, file_config):
def _monkeypatch_cdecimal(options, file_config):
if options.cdecimal:
import cdecimal
- sys.modules['decimal'] = cdecimal
+
+ sys.modules["decimal"] = cdecimal
@post
@@ -266,27 +341,28 @@ def _engine_uri(options, file_config):
if options.db:
for db_token in options.db:
- for db in re.split(r'[,\s]+', db_token):
- if db not in file_config.options('db'):
+ for db in re.split(r"[,\s]+", db_token):
+ if db not in file_config.options("db"):
raise RuntimeError(
"Unknown URI specifier '%s'. "
- "Specify --dbs for known uris."
- % db)
+ "Specify --dbs for known uris." % db
+ )
else:
- db_urls.append(file_config.get('db', db))
+ db_urls.append(file_config.get("db", db))
if not db_urls:
- db_urls.append(file_config.get('db', 'default'))
+ db_urls.append(file_config.get("db", "default"))
config._current = None
for db_url in db_urls:
- if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
+ if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
with open(options.write_idents, "a") as file_:
file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")
cfg = provision.setup_config(
- db_url, options, file_config, provision.FOLLOWER_IDENT)
+ db_url, options, file_config, provision.FOLLOWER_IDENT
+ )
if not config._current:
cfg.set_as_current(cfg, testing)
@@ -295,7 +371,7 @@ def _engine_uri(options, file_config):
@post
def _requirements(options, file_config):
- requirement_cls = file_config.get('sqla_testing', "requirement_cls")
+ requirement_cls = file_config.get("sqla_testing", "requirement_cls")
_setup_requirements(requirement_cls)
@@ -334,22 +410,28 @@ def _prep_testing_database(options, file_config):
pass
else:
for vname in view_names:
- e.execute(schema._DropView(
- schema.Table(vname, schema.MetaData())
- ))
+ e.execute(
+ schema._DropView(
+ schema.Table(vname, schema.MetaData())
+ )
+ )
if config.requirements.schemas.enabled_for_config(cfg):
try:
- view_names = inspector.get_view_names(
- schema="test_schema")
+ view_names = inspector.get_view_names(schema="test_schema")
except NotImplementedError:
pass
else:
for vname in view_names:
- e.execute(schema._DropView(
- schema.Table(vname, schema.MetaData(),
- schema="test_schema")
- ))
+ e.execute(
+ schema._DropView(
+ schema.Table(
+ vname,
+ schema.MetaData(),
+ schema="test_schema",
+ )
+ )
+ )
util.drop_all_tables(e, inspector)
@@ -358,23 +440,29 @@ def _prep_testing_database(options, file_config):
if against(cfg, "postgresql"):
from sqlalchemy.dialects import postgresql
+
for enum in inspector.get_enums("*"):
- e.execute(postgresql.DropEnumType(
- postgresql.ENUM(
- name=enum['name'],
- schema=enum['schema'])))
+ e.execute(
+ postgresql.DropEnumType(
+ postgresql.ENUM(
+ name=enum["name"], schema=enum["schema"]
+ )
+ )
+ )
@post
def _reverse_topological(options, file_config):
if options.reversetop:
from sqlalchemy.orm.util import randomize_unitofwork
+
randomize_unitofwork()
@post
def _post_setup_options(opt, file_config):
from sqlalchemy.testing import config
+
config.options = options
config.file_config = file_config
@@ -382,17 +470,20 @@ def _post_setup_options(opt, file_config):
@post
def _setup_profiling(options, file_config):
from sqlalchemy.testing import profiling
+
profiling._profile_stats = profiling.ProfileStatsFile(
- file_config.get('sqla_testing', 'profile_file'))
+ file_config.get("sqla_testing", "profile_file")
+ )
def want_class(cls):
if not issubclass(cls, fixtures.TestBase):
return False
- elif cls.__name__.startswith('_'):
+ elif cls.__name__.startswith("_"):
return False
- elif config.options.backend_only and not getattr(cls, '__backend__',
- False):
+ elif config.options.backend_only and not getattr(
+ cls, "__backend__", False
+ ):
return False
else:
return True
@@ -405,25 +496,28 @@ def want_method(cls, fn):
return False
elif include_tags:
return (
- hasattr(cls, '__tags__') and
- exclusions.tags(cls.__tags__).include_test(
- include_tags, exclude_tags)
+ hasattr(cls, "__tags__")
+ and exclusions.tags(cls.__tags__).include_test(
+ include_tags, exclude_tags
+ )
) or (
- hasattr(fn, '_sa_exclusion_extend') and
- fn._sa_exclusion_extend.include_test(
- include_tags, exclude_tags)
+ hasattr(fn, "_sa_exclusion_extend")
+ and fn._sa_exclusion_extend.include_test(
+ include_tags, exclude_tags
+ )
)
- elif exclude_tags and hasattr(cls, '__tags__'):
+ elif exclude_tags and hasattr(cls, "__tags__"):
return exclusions.tags(cls.__tags__).include_test(
- include_tags, exclude_tags)
- elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'):
+ include_tags, exclude_tags
+ )
+ elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
else:
return True
def generate_sub_tests(cls, module):
- if getattr(cls, '__backend__', False):
+ if getattr(cls, "__backend__", False):
for cfg in _possible_configs_for_cls(cls):
orig_name = cls.__name__
@@ -431,16 +525,13 @@ def generate_sub_tests(cls, module):
# pytest junit plugin, which is tripped up by the brackets
# and periods, so sanitize
- alpha_name = re.sub('[_\[\]\.]+', '_', cfg.name)
- alpha_name = re.sub('_+$', '', alpha_name)
+ alpha_name = re.sub("[_\[\]\.]+", "_", cfg.name)
+ alpha_name = re.sub("_+$", "", alpha_name)
name = "%s_%s" % (cls.__name__, alpha_name)
subcls = type(
name,
- (cls, ),
- {
- "_sa_orig_cls_name": orig_name,
- "__only_on_config__": cfg
- }
+ (cls,),
+ {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
)
setattr(module, name, subcls)
yield subcls
@@ -454,8 +545,8 @@ def start_test_class(cls):
def stop_test_class(cls):
- #from sqlalchemy import inspect
- #assert not inspect(testing.db).get_table_names()
+ # from sqlalchemy import inspect
+ # assert not inspect(testing.db).get_table_names()
engines.testing_reaper._stop_test_ctx()
try:
if not options.low_connections:
@@ -475,7 +566,7 @@ def final_process_cleanup():
def _setup_engine(cls):
- if getattr(cls, '__engine_options__', None):
+ if getattr(cls, "__engine_options__", None):
eng = engines.testing_engine(options=cls.__engine_options__)
config._current.push_engine(eng, testing)
@@ -485,7 +576,7 @@ def before_test(test, test_module_name, test_class, test_name):
# like a nose id, e.g.:
# "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
- name = getattr(test_class, '_sa_orig_cls_name', test_class.__name__)
+ name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
id_ = "%s.%s.%s" % (test_module_name, name, test_name)
@@ -505,16 +596,16 @@ def _possible_configs_for_cls(cls, reasons=None):
if spec(config_obj):
all_configs.remove(config_obj)
- if getattr(cls, '__only_on__', None):
+ if getattr(cls, "__only_on__", None):
spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
for config_obj in list(all_configs):
if not spec(config_obj):
all_configs.remove(config_obj)
- if getattr(cls, '__only_on_config__', None):
+ if getattr(cls, "__only_on_config__", None):
all_configs.intersection_update([cls.__only_on_config__])
- if hasattr(cls, '__requires__'):
+ if hasattr(cls, "__requires__"):
requirements = config.requirements
for config_obj in list(all_configs):
for requirement in cls.__requires__:
@@ -527,7 +618,7 @@ def _possible_configs_for_cls(cls, reasons=None):
reasons.extend(skip_reasons)
break
- if hasattr(cls, '__prefer_requires__'):
+ if hasattr(cls, "__prefer_requires__"):
non_preferred = set()
requirements = config.requirements
for config_obj in list(all_configs):
@@ -546,30 +637,32 @@ def _do_skips(cls):
reasons = []
all_configs = _possible_configs_for_cls(cls, reasons)
- if getattr(cls, '__skip_if__', False):
- for c in getattr(cls, '__skip_if__'):
+ if getattr(cls, "__skip_if__", False):
+ for c in getattr(cls, "__skip_if__"):
if c():
- config.skip_test("'%s' skipped by %s" % (
- cls.__name__, c.__name__)
+ config.skip_test(
+ "'%s' skipped by %s" % (cls.__name__, c.__name__)
)
if not all_configs:
msg = "'%s' unsupported on any DB implementation %s%s" % (
cls.__name__,
", ".join(
- "'%s(%s)+%s'" % (
+ "'%s(%s)+%s'"
+ % (
config_obj.db.name,
".".join(
- str(dig) for dig in
- exclusions._server_version(config_obj.db)),
- config_obj.db.driver
+ str(dig)
+ for dig in exclusions._server_version(config_obj.db)
+ ),
+ config_obj.db.driver,
)
- for config_obj in config.Config.all_configs()
+ for config_obj in config.Config.all_configs()
),
- ", ".join(reasons)
+ ", ".join(reasons),
)
config.skip_test(msg)
- elif hasattr(cls, '__prefer_backends__'):
+ elif hasattr(cls, "__prefer_backends__"):
non_preferred = set()
spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
for config_obj in all_configs:
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index da682ea00..fd0a48462 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -13,6 +13,7 @@ import os
try:
import xdist # noqa
+
has_xdist = True
except ImportError:
has_xdist = False
@@ -24,30 +25,42 @@ def pytest_addoption(parser):
def make_option(name, **kw):
callback_ = kw.pop("callback", None)
if callback_:
+
class CallableAction(argparse.Action):
- def __call__(self, parser, namespace,
- values, option_string=None):
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
callback_(option_string, values, parser)
+
kw["action"] = CallableAction
zeroarg_callback = kw.pop("zeroarg_callback", None)
if zeroarg_callback:
+
class CallableAction(argparse.Action):
- def __init__(self, option_strings,
- dest, default=False,
- required=False, help=None):
- super(CallableAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- nargs=0,
- const=True,
- default=default,
- required=required,
- help=help)
-
- def __call__(self, parser, namespace,
- values, option_string=None):
+ def __init__(
+ self,
+ option_strings,
+ dest,
+ default=False,
+ required=False,
+ help=None,
+ ):
+ super(CallableAction, self).__init__(
+ option_strings=option_strings,
+ dest=dest,
+ nargs=0,
+ const=True,
+ default=default,
+ required=required,
+ help=help,
+ )
+
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
zeroarg_callback(option_string, values, parser)
+
kw["action"] = CallableAction
group.addoption(name, **kw)
@@ -59,18 +72,18 @@ def pytest_addoption(parser):
def pytest_configure(config):
if hasattr(config, "slaveinput"):
plugin_base.restore_important_follower_config(config.slaveinput)
- plugin_base.configure_follower(
- config.slaveinput["follower_ident"]
- )
+ plugin_base.configure_follower(config.slaveinput["follower_ident"])
else:
- if config.option.write_idents and \
- os.path.exists(config.option.write_idents):
+ if config.option.write_idents and os.path.exists(
+ config.option.write_idents
+ ):
os.remove(config.option.write_idents)
plugin_base.pre_begin(config.option)
- plugin_base.set_coverage_flag(bool(getattr(config.option,
- "cov_source", False)))
+ plugin_base.set_coverage_flag(
+ bool(getattr(config.option, "cov_source", False))
+ )
plugin_base.set_skip_test(pytest.skip.Exception)
@@ -94,10 +107,12 @@ if has_xdist:
node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
from sqlalchemy.testing import provision
+
provision.create_follower_db(node.slaveinput["follower_ident"])
def pytest_testnodedown(node, error):
from sqlalchemy.testing import provision
+
provision.drop_follower_db(node.slaveinput["follower_ident"])
@@ -114,19 +129,22 @@ def pytest_collection_modifyitems(session, config, items):
rebuilt_items = collections.defaultdict(list)
items[:] = [
- item for item in
- items if isinstance(item.parent, pytest.Instance)
- and not item.parent.parent.name.startswith("_")]
+ item
+ for item in items
+ if isinstance(item.parent, pytest.Instance)
+ and not item.parent.parent.name.startswith("_")
+ ]
test_classes = set(item.parent for item in items)
for test_class in test_classes:
for sub_cls in plugin_base.generate_sub_tests(
- test_class.cls, test_class.parent.module):
+ test_class.cls, test_class.parent.module
+ ):
if sub_cls is not test_class.cls:
list_ = rebuilt_items[test_class.cls]
for inst in pytest.Class(
- sub_cls.__name__,
- parent=test_class.parent.parent).collect():
+ sub_cls.__name__, parent=test_class.parent.parent
+ ).collect():
list_.extend(inst.collect())
newitems = []
@@ -139,23 +157,29 @@ def pytest_collection_modifyitems(session, config, items):
# seems like the functions attached to a test class aren't sorted already?
# is that true and why's that? (when using unittest, they're sorted)
- items[:] = sorted(newitems, key=lambda item: (
- item.parent.parent.parent.name,
- item.parent.parent.name,
- item.name
- ))
+ items[:] = sorted(
+ newitems,
+ key=lambda item: (
+ item.parent.parent.parent.name,
+ item.parent.parent.name,
+ item.name,
+ ),
+ )
def pytest_pycollect_makeitem(collector, name, obj):
if inspect.isclass(obj) and plugin_base.want_class(obj):
return pytest.Class(name, parent=collector)
- elif inspect.isfunction(obj) and \
- isinstance(collector, pytest.Instance) and \
- plugin_base.want_method(collector.cls, obj):
+ elif (
+ inspect.isfunction(obj)
+ and isinstance(collector, pytest.Instance)
+ and plugin_base.want_method(collector.cls, obj)
+ ):
return pytest.Function(name, parent=collector)
else:
return []
+
_current_class = None
@@ -180,6 +204,7 @@ def pytest_runtest_setup(item):
global _current_class
class_teardown(item.parent.parent)
_current_class = None
+
item.parent.parent.addfinalizer(finalize)
test_setup(item)
@@ -194,8 +219,9 @@ def pytest_runtest_teardown(item):
def test_setup(item):
- plugin_base.before_test(item, item.parent.module.__name__,
- item.parent.cls, item.name)
+ plugin_base.before_test(
+ item, item.parent.module.__name__, item.parent.cls, item.name
+ )
def test_teardown(item):