Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--  lib/sqlalchemy/testing/__init__.py  61
-rw-r--r--  lib/sqlalchemy/testing/assertions.py  182
-rw-r--r--  lib/sqlalchemy/testing/assertsql.py  148
-rw-r--r--  lib/sqlalchemy/testing/config.py  6
-rw-r--r--  lib/sqlalchemy/testing/engines.py  60
-rw-r--r--  lib/sqlalchemy/testing/entities.py  23
-rw-r--r--  lib/sqlalchemy/testing/exclusions.py  123
-rw-r--r--  lib/sqlalchemy/testing/fixtures.py  85
-rw-r--r--  lib/sqlalchemy/testing/mock.py  3
-rw-r--r--  lib/sqlalchemy/testing/pickleable.py  36
-rw-r--r--  lib/sqlalchemy/testing/plugin/bootstrap.py  7
-rw-r--r--  lib/sqlalchemy/testing/plugin/noseplugin.py  23
-rw-r--r--  lib/sqlalchemy/testing/plugin/plugin_base.py  359
-rw-r--r--  lib/sqlalchemy/testing/plugin/pytestplugin.py  104
-rw-r--r--  lib/sqlalchemy/testing/profiling.py  71
-rw-r--r--  lib/sqlalchemy/testing/provision.py  86
-rw-r--r--  lib/sqlalchemy/testing/replay_fixture.py  81
-rw-r--r--  lib/sqlalchemy/testing/requirements.py  78
-rw-r--r--  lib/sqlalchemy/testing/runner.py  2
-rw-r--r--  lib/sqlalchemy/testing/schema.py  68
-rw-r--r--  lib/sqlalchemy/testing/suite/__init__.py  1
-rw-r--r--  lib/sqlalchemy/testing/suite/test_cte.py  132
-rw-r--r--  lib/sqlalchemy/testing/suite/test_ddl.py  50
-rw-r--r--  lib/sqlalchemy/testing/suite/test_dialect.py  54
-rw-r--r--  lib/sqlalchemy/testing/suite/test_insert.py  221
-rw-r--r--  lib/sqlalchemy/testing/suite/test_reflection.py  762
-rw-r--r--  lib/sqlalchemy/testing/suite/test_results.py  242
-rw-r--r--  lib/sqlalchemy/testing/suite/test_select.py  483
-rw-r--r--  lib/sqlalchemy/testing/suite/test_sequence.py  136
-rw-r--r--  lib/sqlalchemy/testing/suite/test_types.py  702
-rw-r--r--  lib/sqlalchemy/testing/suite/test_update_delete.py  37
-rw-r--r--  lib/sqlalchemy/testing/util.py  66
-rw-r--r--  lib/sqlalchemy/testing/warnings.py  21
33 files changed, 2399 insertions, 2114 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index 413a492b8..f46ca4528 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -10,23 +10,62 @@ from .warnings import assert_warnings
from . import config
-from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
- fails_on, fails_on_everything_except, skip, only_on, exclude, \
- against as _against, _server_version, only_if, fails
+from .exclusions import (
+ db_spec,
+ _is_excluded,
+ fails_if,
+ skip_if,
+ future,
+ fails_on,
+ fails_on_everything_except,
+ skip,
+ only_on,
+ exclude,
+ against as _against,
+ _server_version,
+ only_if,
+ fails,
+)
def against(*queries):
return _against(config._current, *queries)
-from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
- eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \
- assert_raises_message, AssertsCompiledSQL, ComparesTables, \
- AssertsExecutionResults, expect_deprecated, expect_warnings, \
- in_, not_in_, eq_ignore_whitespace, eq_regex, is_true, is_false
-from .util import run_as_contextmanager, rowset, fail, \
- provide_metadata, adict, force_drop_names, \
- teardown_events
+from .assertions import (
+ emits_warning,
+ emits_warning_on,
+ uses_deprecated,
+ eq_,
+ ne_,
+ le_,
+ is_,
+ is_not_,
+ startswith_,
+ assert_raises,
+ assert_raises_message,
+ AssertsCompiledSQL,
+ ComparesTables,
+ AssertsExecutionResults,
+ expect_deprecated,
+ expect_warnings,
+ in_,
+ not_in_,
+ eq_ignore_whitespace,
+ eq_regex,
+ is_true,
+ is_false,
+)
+
+from .util import (
+ run_as_contextmanager,
+ rowset,
+ fail,
+ provide_metadata,
+ adict,
+ force_drop_names,
+ teardown_events,
+)
crashes = skip
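
The flattened import list above re-exports the assertion helpers at the package top level. A minimal sketch of a few of them, with names taken from the import list and behavior inferred from their definitions later in this diff:

    from sqlalchemy.testing import eq_, in_, assert_raises_message

    eq_(1 + 1, 2)            # passes; repr-based message on failure
    in_("a", ["a", "b"])     # asserts membership
    assert_raises_message(
        ZeroDivisionError,
        "division",          # regex matched against the exception text
        lambda: 1 / 0,
    )
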
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index e42376921..73ab4556a 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -86,6 +86,7 @@ def emits_warning_on(db, *messages):
were in fact seen.
"""
+
@decorator
def decorate(fn, *args, **kw):
with expect_warnings_on(db, assert_=False, *messages):
@@ -114,12 +115,14 @@ def uses_deprecated(*messages):
def decorate(fn, *args, **kw):
with expect_deprecated(*messages, assert_=False):
return fn(*args, **kw)
+
return decorate
@contextlib.contextmanager
-def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
- py2konly=False):
+def _expect_warnings(
+ exc_cls, messages, regex=True, assert_=True, py2konly=False
+):
if regex:
filters = [re.compile(msg, re.I | re.S) for msg in messages]
@@ -145,8 +148,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
return
for filter_ in filters:
- if (regex and filter_.match(msg)) or \
- (not regex and filter_ == msg):
+ if (regex and filter_.match(msg)) or (
+ not regex and filter_ == msg
+ ):
seen.discard(filter_)
break
else:
@@ -156,8 +160,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
yield
if assert_ and (not py2konly or not compat.py3k):
- assert not seen, "Warnings were not seen: %s" % \
- ", ".join("%r" % (s.pattern if regex else s) for s in seen)
+ assert not seen, "Warnings were not seen: %s" % ", ".join(
+ "%r" % (s.pattern if regex else s) for s in seen
+ )
def global_cleanup_assertions():
@@ -170,6 +175,7 @@ def global_cleanup_assertions():
"""
_assert_no_stray_pool_connections()
+
_STRAY_CONNECTION_FAILURES = 0
@@ -187,8 +193,10 @@ def _assert_no_stray_pool_connections():
# OK, let's be somewhat forgiving.
_STRAY_CONNECTION_FAILURES += 1
- print("Encountered a stray connection in test cleanup: %s"
- % str(pool._refs))
+ print(
+ "Encountered a stray connection in test cleanup: %s"
+ % str(pool._refs)
+ )
# then do a real GC sweep. We shouldn't even be here
# so a single sweep should really be doing it, otherwise
# there's probably a real unreachable cycle somewhere.
@@ -206,8 +214,8 @@ def _assert_no_stray_pool_connections():
pool._refs.clear()
_STRAY_CONNECTION_FAILURES = 0
warnings.warn(
- "Stray connection refused to leave "
- "after gc.collect(): %s" % err)
+ "Stray connection refused to leave " "after gc.collect(): %s" % err
+ )
elif _STRAY_CONNECTION_FAILURES > 10:
assert False, "Encountered more than 10 stray connections"
_STRAY_CONNECTION_FAILURES = 0
@@ -263,14 +271,16 @@ def not_in_(a, b, msg=None):
def startswith_(a, fragment, msg=None):
"""Assert a.startswith(fragment), with repr messaging on failure."""
assert a.startswith(fragment), msg or "%r does not start with %r" % (
- a, fragment)
+ a,
+ fragment,
+ )
def eq_ignore_whitespace(a, b, msg=None):
- a = re.sub(r'^\s+?|\n', "", a)
- a = re.sub(r' {2,}', " ", a)
- b = re.sub(r'^\s+?|\n', "", b)
- b = re.sub(r' {2,}', " ", b)
+ a = re.sub(r"^\s+?|\n", "", a)
+ a = re.sub(r" {2,}", " ", a)
+ b = re.sub(r"^\s+?|\n", "", b)
+ b = re.sub(r" {2,}", " ", b)
assert a == b, msg or "%r != %r" % (a, b)
@@ -291,32 +301,41 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
callable_(*args, **kwargs)
assert False, "Callable did not raise an exception"
except except_cls as e:
- assert re.search(
- msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
- print(util.text_type(e).encode('utf-8'))
+ assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (
+ msg,
+ e,
+ )
+ print(util.text_type(e).encode("utf-8"))
+
class AssertsCompiledSQL(object):
- def assert_compile(self, clause, result, params=None,
- checkparams=None, dialect=None,
- checkpositional=None,
- check_prefetch=None,
- use_default_dialect=False,
- allow_dialect_select=False,
- literal_binds=False,
- schema_translate_map=None):
+ def assert_compile(
+ self,
+ clause,
+ result,
+ params=None,
+ checkparams=None,
+ dialect=None,
+ checkpositional=None,
+ check_prefetch=None,
+ use_default_dialect=False,
+ allow_dialect_select=False,
+ literal_binds=False,
+ schema_translate_map=None,
+ ):
if use_default_dialect:
dialect = default.DefaultDialect()
elif allow_dialect_select:
dialect = None
else:
if dialect is None:
- dialect = getattr(self, '__dialect__', None)
+ dialect = getattr(self, "__dialect__", None)
if dialect is None:
dialect = config.db.dialect
- elif dialect == 'default':
+ elif dialect == "default":
dialect = default.DefaultDialect()
- elif dialect == 'default_enhanced':
+ elif dialect == "default_enhanced":
dialect = default.StrCompileDialect()
elif isinstance(dialect, util.string_types):
dialect = url.URL(dialect).get_dialect()()
@@ -325,13 +344,13 @@ class AssertsCompiledSQL(object):
compile_kwargs = {}
if schema_translate_map:
- kw['schema_translate_map'] = schema_translate_map
+ kw["schema_translate_map"] = schema_translate_map
if params is not None:
- kw['column_keys'] = list(params)
+ kw["column_keys"] = list(params)
if literal_binds:
- compile_kwargs['literal_binds'] = True
+ compile_kwargs["literal_binds"] = True
if isinstance(clause, orm.Query):
context = clause._compile_context()
@@ -343,25 +362,27 @@ class AssertsCompiledSQL(object):
clause = stmt_mock.mock_calls[0][1][0]
if compile_kwargs:
- kw['compile_kwargs'] = compile_kwargs
+ kw["compile_kwargs"] = compile_kwargs
c = clause.compile(dialect=dialect, **kw)
- param_str = repr(getattr(c, 'params', {}))
+ param_str = repr(getattr(c, "params", {}))
if util.py3k:
- param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
+ param_str = param_str.encode("utf-8").decode("ascii", "ignore")
print(
- ("\nSQL String:\n" +
- util.text_type(c) +
- param_str).encode('utf-8'))
+ ("\nSQL String:\n" + util.text_type(c) + param_str).encode(
+ "utf-8"
+ )
+ )
else:
print(
- "\nSQL String:\n" +
- util.text_type(c).encode('utf-8') +
- param_str)
+ "\nSQL String:\n"
+ + util.text_type(c).encode("utf-8")
+ + param_str
+ )
- cc = re.sub(r'[\n\t]', '', util.text_type(c))
+ cc = re.sub(r"[\n\t]", "", util.text_type(c))
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
@@ -375,7 +396,6 @@ class AssertsCompiledSQL(object):
class ComparesTables(object):
-
def assert_tables_equal(self, table, reflected_table, strict_types=False):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
@@ -386,8 +406,10 @@ class ComparesTables(object):
if strict_types:
msg = "Type '%s' doesn't correspond to type '%s'"
- assert isinstance(reflected_c.type, type(c.type)), \
- msg % (reflected_c.type, c.type)
+ assert isinstance(reflected_c.type, type(c.type)), msg % (
+ reflected_c.type,
+ c.type,
+ )
else:
self.assert_types_base(reflected_c, c)
@@ -396,20 +418,22 @@ class ComparesTables(object):
eq_(
{f.column.name for f in c.foreign_keys},
- {f.column.name for f in reflected_c.foreign_keys}
+ {f.column.name for f in reflected_c.foreign_keys},
)
if c.server_default:
- assert isinstance(reflected_c.server_default,
- schema.FetchedValue)
+ assert isinstance(
+ reflected_c.server_default, schema.FetchedValue
+ )
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name] is not None
def assert_types_base(self, c1, c2):
- assert c1.type._compare_type_affinity(c2.type),\
- "On column %r, type '%s' doesn't correspond to type '%s'" % \
- (c1.name, c1.type, c2.type)
+ assert c1.type._compare_type_affinity(c2.type), (
+ "On column %r, type '%s' doesn't correspond to type '%s'"
+ % (c1.name, c1.type, c2.type)
+ )
class AssertsExecutionResults(object):
@@ -419,15 +443,19 @@ class AssertsExecutionResults(object):
self.assert_list(result, class_, objects)
def assert_list(self, result, class_, list):
- self.assert_(len(result) == len(list),
- "result list is not the same size as test list, " +
- "for class " + class_.__name__)
+ self.assert_(
+ len(result) == len(list),
+ "result list is not the same size as test list, "
+ + "for class "
+ + class_.__name__,
+ )
for i in range(0, len(list)):
self.assert_row(class_, result[i], list[i])
def assert_row(self, class_, rowobj, desc):
- self.assert_(rowobj.__class__ is class_,
- "item class is not " + repr(class_))
+ self.assert_(
+ rowobj.__class__ is class_, "item class is not " + repr(class_)
+ )
for key, value in desc.items():
if isinstance(value, tuple):
if isinstance(value[1], list):
@@ -435,9 +463,11 @@ class AssertsExecutionResults(object):
else:
self.assert_row(value[0], getattr(rowobj, key), value[1])
else:
- self.assert_(getattr(rowobj, key) == value,
- "attribute %s value %s does not match %s" % (
- key, getattr(rowobj, key), value))
+ self.assert_(
+ getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s"
+ % (key, getattr(rowobj, key), value),
+ )
def assert_unordered_result(self, result, cls, *expected):
"""As assert_result, but the order of objects is not considered.
@@ -453,14 +483,19 @@ class AssertsExecutionResults(object):
found = util.IdentitySet(result)
expected = {immutabledict(e) for e in expected}
- for wrong in util.itertools_filterfalse(lambda o:
- isinstance(o, cls), found):
- fail('Unexpected type "%s", expected "%s"' % (
- type(wrong).__name__, cls.__name__))
+ for wrong in util.itertools_filterfalse(
+ lambda o: isinstance(o, cls), found
+ ):
+ fail(
+ 'Unexpected type "%s", expected "%s"'
+ % (type(wrong).__name__, cls.__name__)
+ )
if len(found) != len(expected):
- fail('Unexpected object count "%s", expected "%s"' % (
- len(found), len(expected)))
+ fail(
+ 'Unexpected object count "%s", expected "%s"'
+ % (len(found), len(expected))
+ )
NOVALUE = object()
@@ -469,7 +504,8 @@ class AssertsExecutionResults(object):
if isinstance(value, tuple):
try:
self.assert_unordered_result(
- getattr(obj, key), value[0], *value[1])
+ getattr(obj, key), value[0], *value[1]
+ )
except AssertionError:
return False
else:
@@ -484,8 +520,9 @@ class AssertsExecutionResults(object):
break
else:
fail(
- "Expected %s instance with attributes %s not found." % (
- cls.__name__, repr(expected_item)))
+ "Expected %s instance with attributes %s not found."
+ % (cls.__name__, repr(expected_item))
+ )
return True
def sql_execution_asserter(self, db=None):
@@ -505,9 +542,9 @@ class AssertsExecutionResults(object):
newrules = []
for rule in rules:
if isinstance(rule, dict):
- newrule = assertsql.AllOf(*[
- assertsql.CompiledSQL(k, v) for k, v in rule.items()
- ])
+ newrule = assertsql.AllOf(
+ *[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
+ )
else:
newrule = assertsql.CompiledSQL(*rule)
newrules.append(newrule)
@@ -516,7 +553,8 @@ class AssertsExecutionResults(object):
def assert_sql_count(self, db, callable_, count):
self.assert_sql_execution(
- db, callable_, assertsql.CountStatements(count))
+ db, callable_, assertsql.CountStatements(count)
+ )
def assert_multiple_sql_count(self, dbs, callable_, counts):
recs = [
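
The assert_compile() signature reworked above drives most of the compiler tests. A hedged sketch of its use; the table, expected SQL, and checkparams comparison are illustrative, and __dialect__ = "default" avoids needing a configured database:

    from sqlalchemy import table, column, select
    from sqlalchemy.testing import AssertsCompiledSQL

    class SelectCompileExample(AssertsCompiledSQL):
        __dialect__ = "default"

        def test_where(self):
            t = table("t", column("x"))
            self.assert_compile(
                select([t.c.x]).where(t.c.x == 5),
                "SELECT t.x FROM t WHERE t.x = :x_1",
                checkparams={"x_1": 5},
            )

    SelectCompileExample().test_where()  # prints the SQL string, then asserts
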
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 7a525589d..d8e924cb6 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -26,8 +26,10 @@ class AssertRule(object):
pass
def no_more_statements(self):
- assert False, 'All statements are complete, but pending '\
- 'assertion rules remain'
+ assert False, (
+ "All statements are complete, but pending "
+ "assertion rules remain"
+ )
class SQLMatchRule(AssertRule):
@@ -44,12 +46,17 @@ class CursorSQL(SQLMatchRule):
def process_statement(self, execute_observed):
stmt = execute_observed.statements[0]
if self.statement != stmt.statement or (
- self.params is not None and self.params != stmt.parameters):
- self.errormessage = \
- "Testing for exact SQL %s parameters %s received %s %s" % (
- self.statement, self.params,
- stmt.statement, stmt.parameters
+ self.params is not None and self.params != stmt.parameters
+ ):
+ self.errormessage = (
+ "Testing for exact SQL %s parameters %s received %s %s"
+ % (
+ self.statement,
+ self.params,
+ stmt.statement,
+ stmt.parameters,
)
+ )
else:
execute_observed.statements.pop(0)
self.is_consumed = True
@@ -58,23 +65,22 @@ class CursorSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
-
- def __init__(self, statement, params=None, dialect='default'):
+ def __init__(self, statement, params=None, dialect="default"):
self.statement = statement
self.params = params
self.dialect = dialect
def _compare_sql(self, execute_observed, received_statement):
- stmt = re.sub(r'[\n\t]', '', self.statement)
+ stmt = re.sub(r"[\n\t]", "", self.statement)
return received_statement == stmt
def _compile_dialect(self, execute_observed):
- if self.dialect == 'default':
+ if self.dialect == "default":
return DefaultDialect()
else:
# ugh
- if self.dialect == 'postgresql':
- params = {'implicit_returning': True}
+ if self.dialect == "postgresql":
+ params = {"implicit_returning": True}
else:
params = {}
return url.URL(self.dialect).get_dialect()(**params)
@@ -86,36 +92,39 @@ class CompiledSQL(SQLMatchRule):
context = execute_observed.context
compare_dialect = self._compile_dialect(execute_observed)
if isinstance(context.compiled.statement, _DDLCompiles):
- compiled = \
- context.compiled.statement.compile(
- dialect=compare_dialect,
- schema_translate_map=context.
- execution_options.get('schema_translate_map'))
+ compiled = context.compiled.statement.compile(
+ dialect=compare_dialect,
+ schema_translate_map=context.execution_options.get(
+ "schema_translate_map"
+ ),
+ )
else:
- compiled = (
- context.compiled.statement.compile(
- dialect=compare_dialect,
- column_keys=context.compiled.column_keys,
- inline=context.compiled.inline,
- schema_translate_map=context.
- execution_options.get('schema_translate_map'))
+ compiled = context.compiled.statement.compile(
+ dialect=compare_dialect,
+ column_keys=context.compiled.column_keys,
+ inline=context.compiled.inline,
+ schema_translate_map=context.execution_options.get(
+ "schema_translate_map"
+ ),
)
- _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
+ _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
parameters = execute_observed.parameters
if not parameters:
_received_parameters = [compiled.construct_params()]
else:
_received_parameters = [
- compiled.construct_params(m) for m in parameters]
+ compiled.construct_params(m) for m in parameters
+ ]
return _received_statement, _received_parameters
def process_statement(self, execute_observed):
context = execute_observed.context
- _received_statement, _received_parameters = \
- self._received_statement(execute_observed)
+ _received_statement, _received_parameters = self._received_statement(
+ execute_observed
+ )
params = self._all_params(context)
equivalent = self._compare_sql(execute_observed, _received_statement)
@@ -132,8 +141,10 @@ class CompiledSQL(SQLMatchRule):
for param_key in param:
# a key in param did not match current
# 'received'
- if param_key not in received or \
- received[param_key] != param[param_key]:
+ if (
+ param_key not in received
+ or received[param_key] != param[param_key]
+ ):
break
else:
# all keys in param matched 'received';
@@ -153,8 +164,8 @@ class CompiledSQL(SQLMatchRule):
self.errormessage = None
else:
self.errormessage = self._failure_message(params) % {
- 'received_statement': _received_statement,
- 'received_parameters': _received_parameters
+ "received_statement": _received_statement,
+ "received_parameters": _received_parameters,
}
def _all_params(self, context):
@@ -171,11 +182,10 @@ class CompiledSQL(SQLMatchRule):
def _failure_message(self, expected_params):
return (
- 'Testing for compiled statement %r partial params %r, '
- 'received %%(received_statement)r with params '
- '%%(received_parameters)r' % (
- self.statement.replace('%', '%%'), expected_params
- )
+ "Testing for compiled statement %r partial params %r, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r"
+ % (self.statement.replace("%", "%%"), expected_params)
)
@@ -185,15 +195,13 @@ class RegexSQL(CompiledSQL):
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
- self.dialect = 'default'
+ self.dialect = "default"
def _failure_message(self, expected_params):
return (
- 'Testing for compiled statement ~%r partial params %r, '
- 'received %%(received_statement)r with params '
- '%%(received_parameters)r' % (
- self.orig_regex, expected_params
- )
+ "Testing for compiled statement ~%r partial params %r, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r" % (self.orig_regex, expected_params)
)
def _compare_sql(self, execute_observed, received_statement):
@@ -205,12 +213,13 @@ class DialectSQL(CompiledSQL):
return execute_observed.context.dialect
def _compare_no_space(self, real_stmt, received_stmt):
- stmt = re.sub(r'[\n\t]', '', real_stmt)
+ stmt = re.sub(r"[\n\t]", "", real_stmt)
return received_stmt == stmt
def _received_statement(self, execute_observed):
- received_stmt, received_params = super(DialectSQL, self).\
- _received_statement(execute_observed)
+ received_stmt, received_params = super(
+ DialectSQL, self
+ )._received_statement(execute_observed)
# TODO: why do we need this part?
for real_stmt in execute_observed.statements:
@@ -219,34 +228,33 @@ class DialectSQL(CompiledSQL):
else:
raise AssertionError(
"Can't locate compiled statement %r in list of "
- "statements actually invoked" % received_stmt)
+ "statements actually invoked" % received_stmt
+ )
return received_stmt, execute_observed.context.compiled_parameters
def _compare_sql(self, execute_observed, received_statement):
- stmt = re.sub(r'[\n\t]', '', self.statement)
+ stmt = re.sub(r"[\n\t]", "", self.statement)
# convert our comparison statement to have the
# paramstyle of the received
paramstyle = execute_observed.context.dialect.paramstyle
- if paramstyle == 'pyformat':
- stmt = re.sub(
- r':([\w_]+)', r"%(\1)s", stmt)
+ if paramstyle == "pyformat":
+ stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
else:
# positional params
repl = None
- if paramstyle == 'qmark':
+ if paramstyle == "qmark":
repl = "?"
- elif paramstyle == 'format':
+ elif paramstyle == "format":
repl = r"%s"
- elif paramstyle == 'numeric':
+ elif paramstyle == "numeric":
repl = None
- stmt = re.sub(r':([\w_]+)', repl, stmt)
+ stmt = re.sub(r":([\w_]+)", repl, stmt)
return received_statement == stmt
class CountStatements(AssertRule):
-
def __init__(self, count):
self.count = count
self._statement_count = 0
@@ -256,12 +264,13 @@ class CountStatements(AssertRule):
def no_more_statements(self):
if self.count != self._statement_count:
- assert False, 'desired statement count %d does not match %d' \
- % (self.count, self._statement_count)
+ assert False, "desired statement count %d does not match %d" % (
+ self.count,
+ self._statement_count,
+ )
class AllOf(AssertRule):
-
def __init__(self, *rules):
self.rules = set(rules)
@@ -283,7 +292,6 @@ class AllOf(AssertRule):
class EachOf(AssertRule):
-
def __init__(self, *rules):
self.rules = list(rules)
@@ -309,7 +317,6 @@ class EachOf(AssertRule):
class Or(AllOf):
-
def process_statement(self, execute_observed):
for rule in self.rules:
rule.process_statement(execute_observed)
@@ -331,7 +338,8 @@ class SQLExecuteObserved(object):
class SQLCursorExecuteObserved(
collections.namedtuple(
"SQLCursorExecuteObserved",
- ["statement", "parameters", "context", "executemany"])
+ ["statement", "parameters", "context", "executemany"],
+ )
):
pass
@@ -374,21 +382,25 @@ def assert_engine(engine):
orig[:] = clauseelement, multiparams, params
@event.listens_for(engine, "after_cursor_execute")
- def cursor_execute(conn, cursor, statement, parameters,
- context, executemany):
+ def cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
if not context:
return
# then grab real cursor statements and associate them all
# around a single context
- if asserter.accumulated and \
- asserter.accumulated[-1].context is context:
+ if (
+ asserter.accumulated
+ and asserter.accumulated[-1].context is context
+ ):
obs = asserter.accumulated[-1]
else:
obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
asserter.accumulated.append(obs)
obs.statements.append(
SQLCursorExecuteObserved(
- statement, parameters, context, executemany)
+ statement, parameters, context, executemany
+ )
)
try:
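
The rule classes reshaped above (CursorSQL, CompiledSQL, RegexSQL, EachOf, AllOf) compose into expectations for AssertsExecutionResults.assert_sql_execution(). Building a rule tree has no side effects, so a sketch with hypothetical statements and parameters; the RegexSQL(regex, params) signature is assumed from the hunk above:

    from sqlalchemy.testing import assertsql

    rules = assertsql.EachOf(
        assertsql.CompiledSQL(
            "INSERT INTO users (name) VALUES (:name)",
            {"name": "spongebob"},
        ),
        assertsql.RegexSQL(
            r"SELECT users\.id FROM users", {"id": 1}
        ),
    )
    # A test would then hand these to
    # AssertsExecutionResults.assert_sql_execution(db, callable_, rules).
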
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index e9cfb3de9..1ff282af5 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -64,8 +64,9 @@ class Config(object):
assert _current, "Can't push without a default Config set up"
cls.push(
Config(
- db, _current.db_opts, _current.options, _current.file_config),
- namespace
+ db, _current.db_opts, _current.options, _current.file_config
+ ),
+ namespace,
)
@classmethod
@@ -94,4 +95,3 @@ class Config(object):
def skip_test(msg):
raise _skip_test_exception(msg)
-
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index d17e30edf..074e3b338 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -16,7 +16,6 @@ import warnings
class ConnectionKiller(object):
-
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
self.testing_engines = weakref.WeakKeyDictionary()
@@ -39,8 +38,8 @@ class ConnectionKiller(object):
fn()
except Exception as e:
warnings.warn(
- "testing_reaper couldn't "
- "rollback/close connection: %s" % e)
+ "testing_reaper couldn't " "rollback/close connection: %s" % e
+ )
def rollback_all(self):
for rec in list(self.proxy_refs):
@@ -97,18 +96,19 @@ class ConnectionKiller(object):
if rec.is_valid:
assert False
+
testing_reaper = ConnectionKiller()
def drop_all_tables(metadata, bind):
testing_reaper.close_all()
- if hasattr(bind, 'close'):
+ if hasattr(bind, "close"):
bind.close()
if not config.db.dialect.supports_alter:
from . import assertions
- with assertions.expect_warnings(
- "Can't sort tables", assert_=False):
+
+ with assertions.expect_warnings("Can't sort tables", assert_=False):
metadata.drop_all(bind)
else:
metadata.drop_all(bind)
@@ -151,19 +151,20 @@ def close_open_connections(fn, *args, **kw):
def all_dialects(exclude=None):
import sqlalchemy.databases as d
+
for name in d.__all__:
# TEMPORARY
if exclude and name in exclude:
continue
mod = getattr(d, name, None)
if not mod:
- mod = getattr(__import__(
- 'sqlalchemy.databases.%s' % name).databases, name)
+ mod = getattr(
+ __import__("sqlalchemy.databases.%s" % name).databases, name
+ )
yield mod.dialect()
class ReconnectFixture(object):
-
def __init__(self, dbapi):
self.dbapi = dbapi
self.connections = []
@@ -191,8 +192,8 @@ class ReconnectFixture(object):
fn()
except Exception as e:
warnings.warn(
- "ReconnectFixture couldn't "
- "close connection: %s" % e)
+ "ReconnectFixture couldn't " "close connection: %s" % e
+ )
def shutdown(self, stop=False):
# TODO: this doesn't cover all cases
@@ -214,7 +215,7 @@ def reconnecting_engine(url=None, options=None):
dbapi = config.db.dialect.dbapi
if not options:
options = {}
- options['module'] = ReconnectFixture(dbapi)
+ options["module"] = ReconnectFixture(dbapi)
engine = testing_engine(url, options)
_dispose = engine.dispose
@@ -238,7 +239,7 @@ def testing_engine(url=None, options=None):
if not options:
use_reaper = True
else:
- use_reaper = options.pop('use_reaper', True)
+ use_reaper = options.pop("use_reaper", True)
url = url or config.db.url
@@ -253,15 +254,15 @@ def testing_engine(url=None, options=None):
default_opt.update(options)
engine = create_engine(url, **options)
- engine._has_events = True # enable event blocks, helps with profiling
+ engine._has_events = True # enable event blocks, helps with profiling
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
engine.pool._max_overflow = 0
if use_reaper:
- event.listen(engine.pool, 'connect', testing_reaper.connect)
- event.listen(engine.pool, 'checkout', testing_reaper.checkout)
- event.listen(engine.pool, 'invalidate', testing_reaper.invalidate)
+ event.listen(engine.pool, "connect", testing_reaper.connect)
+ event.listen(engine.pool, "checkout", testing_reaper.checkout)
+ event.listen(engine.pool, "invalidate", testing_reaper.invalidate)
testing_reaper.add_engine(engine)
return engine
@@ -290,19 +291,17 @@ def mock_engine(dialect_name=None):
buffer.append(sql)
def assert_sql(stmts):
- recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
+ recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
assert recv == stmts, recv
def print_sql():
d = engine.dialect
- return "\n".join(
- str(s.compile(dialect=d))
- for s in engine.mock
- )
-
- engine = create_engine(dialect_name + '://',
- strategy='mock', executor=executor)
- assert not hasattr(engine, 'mock')
+ return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)
+
+ engine = create_engine(
+ dialect_name + "://", strategy="mock", executor=executor
+ )
+ assert not hasattr(engine, "mock")
engine.mock = buffer
engine.assert_sql = assert_sql
engine.print_sql = print_sql
@@ -358,14 +357,15 @@ class DBAPIProxyConnection(object):
return getattr(self.conn, key)
-def proxying_engine(conn_cls=DBAPIProxyConnection,
- cursor_cls=DBAPIProxyCursor):
+def proxying_engine(
+ conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor
+):
"""Produce an engine that provides proxy hooks for
common methods.
"""
+
def mock_conn():
return conn_cls(config.db, cursor_cls)
- return testing_engine(options={'creator': mock_conn})
-
+ return testing_engine(options={"creator": mock_conn})
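
mock_engine(), reformatted above, returns a create_engine(strategy="mock") engine whose executor buffers statements instead of executing them. A sketch, passing the dialect name explicitly (the default falls back to the configured test database) and using checkfirst=False since a mock engine cannot introspect:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.testing.engines import mock_engine

    engine = mock_engine("sqlite")
    metadata = MetaData()
    Table("t", metadata, Column("id", Integer, primary_key=True))
    metadata.create_all(engine, checkfirst=False)  # DDL lands in engine.mock
    print(engine.print_sql())                      # renders it via the dialect
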
diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py
index b634735fc..42c42149c 100644
--- a/lib/sqlalchemy/testing/entities.py
+++ b/lib/sqlalchemy/testing/entities.py
@@ -12,7 +12,6 @@ _repr_stack = set()
class BasicEntity(object):
-
def __init__(self, **kw):
for key, value in kw.items():
setattr(self, key, value)
@@ -24,17 +23,22 @@ class BasicEntity(object):
try:
return "%s(%s)" % (
(self.__class__.__name__),
- ', '.join(["%s=%r" % (key, getattr(self, key))
- for key in sorted(self.__dict__.keys())
- if not key.startswith('_')]))
+ ", ".join(
+ [
+ "%s=%r" % (key, getattr(self, key))
+ for key in sorted(self.__dict__.keys())
+ if not key.startswith("_")
+ ]
+ ),
+ )
finally:
_repr_stack.remove(id(self))
+
_recursion_stack = set()
class ComparableEntity(BasicEntity):
-
def __hash__(self):
return hash(self.__class__)
@@ -75,7 +79,7 @@ class ComparableEntity(BasicEntity):
b = other
for attr in list(a.__dict__):
- if attr.startswith('_'):
+ if attr.startswith("_"):
continue
value = getattr(a, attr)
@@ -85,9 +89,10 @@ class ComparableEntity(BasicEntity):
except (AttributeError, sa_exc.UnboundExecutionError):
return False
- if hasattr(value, '__iter__'):
- if hasattr(value, '__getitem__') and not hasattr(
- value, 'keys'):
+ if hasattr(value, "__iter__"):
+ if hasattr(value, "__getitem__") and not hasattr(
+ value, "keys"
+ ):
if list(value) != list(battr):
return False
else:
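
ComparableEntity, whose comparison logic is touched above, takes arbitrary keyword attributes via BasicEntity and compares instances attribute by attribute, skipping underscore-prefixed names. A small sketch:

    from sqlalchemy.testing.entities import ComparableEntity

    class User(ComparableEntity):
        pass

    assert User(id=1, name="ed") == User(id=1, name="ed")
    assert User(id=1, name="ed") != User(id=1, name="wendy")
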
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
index 512fffb3b..9ed9e42c3 100644
--- a/lib/sqlalchemy/testing/exclusions.py
+++ b/lib/sqlalchemy/testing/exclusions.py
@@ -16,6 +16,7 @@ from . import config
from .. import util
from ..util import decorator
+
def skip_if(predicate, reason=None):
rule = compound()
pred = _as_predicate(predicate, reason)
@@ -70,15 +71,15 @@ class compound(object):
def matching_config_reasons(self, config):
return [
- predicate._as_string(config) for predicate
- in self.skips.union(self.fails)
+ predicate._as_string(config)
+ for predicate in self.skips.union(self.fails)
if predicate(config)
]
def include_test(self, include_tags, exclude_tags):
return bool(
- not self.tags.intersection(exclude_tags) and
- (not include_tags or self.tags.intersection(include_tags))
+ not self.tags.intersection(exclude_tags)
+ and (not include_tags or self.tags.intersection(include_tags))
)
def _extend(self, other):
@@ -87,13 +88,14 @@ class compound(object):
self.tags.update(other.tags)
def __call__(self, fn):
- if hasattr(fn, '_sa_exclusion_extend'):
+ if hasattr(fn, "_sa_exclusion_extend"):
fn._sa_exclusion_extend._extend(self)
return fn
@decorator
def decorate(fn, *args, **kw):
return self._do(config._current, fn, *args, **kw)
+
decorated = decorate(fn)
decorated._sa_exclusion_extend = self
return decorated
@@ -113,10 +115,7 @@ class compound(object):
def _do(self, cfg, fn, *args, **kw):
for skip in self.skips:
if skip(cfg):
- msg = "'%s' : %s" % (
- fn.__name__,
- skip._as_string(cfg)
- )
+ msg = "'%s' : %s" % (fn.__name__, skip._as_string(cfg))
config.skip_test(msg)
try:
@@ -127,16 +126,20 @@ class compound(object):
self._expect_success(cfg, name=fn.__name__)
return return_value
- def _expect_failure(self, config, ex, name='block'):
+ def _expect_failure(self, config, ex, name="block"):
for fail in self.fails:
if fail(config):
- print(("%s failed as expected (%s): %s " % (
- name, fail._as_string(config), str(ex))))
+ print(
+ (
+ "%s failed as expected (%s): %s "
+ % (name, fail._as_string(config), str(ex))
+ )
+ )
break
else:
util.raise_from_cause(ex)
- def _expect_success(self, config, name='block'):
+ def _expect_success(self, config, name="block"):
if not self.fails:
return
for fail in self.fails:
@@ -144,13 +147,12 @@ class compound(object):
break
else:
raise AssertionError(
- "Unexpected success for '%s' (%s)" %
- (
+ "Unexpected success for '%s' (%s)"
+ % (
name,
" and ".join(
- fail._as_string(config)
- for fail in self.fails
- )
+ fail._as_string(config) for fail in self.fails
+ ),
)
)
@@ -186,21 +188,24 @@ class Predicate(object):
return predicate
elif isinstance(predicate, (list, set)):
return OrPredicate(
- [cls.as_predicate(pred) for pred in predicate],
- description)
+ [cls.as_predicate(pred) for pred in predicate], description
+ )
elif isinstance(predicate, tuple):
return SpecPredicate(*predicate)
elif isinstance(predicate, util.string_types):
tokens = re.match(
- r'([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?', predicate)
+ r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
+ )
if not tokens:
raise ValueError(
- "Couldn't locate DB name in predicate: %r" % predicate)
+ "Couldn't locate DB name in predicate: %r" % predicate
+ )
db = tokens.group(1)
op = tokens.group(2)
spec = (
tuple(int(d) for d in tokens.group(3).split("."))
- if tokens.group(3) else None
+ if tokens.group(3)
+ else None
)
return SpecPredicate(db, op, spec, description=description)
@@ -215,11 +220,13 @@ class Predicate(object):
bool_ = not negate
return self.description % {
"driver": config.db.url.get_driver_name()
- if config else "<no driver>",
+ if config
+ else "<no driver>",
"database": config.db.url.get_backend_name()
- if config else "<no database>",
+ if config
+ else "<no database>",
"doesnt_support": "doesn't support" if bool_ else "does support",
- "does_support": "does support" if bool_ else "doesn't support"
+ "does_support": "does support" if bool_ else "doesn't support",
}
def _as_string(self, config=None, negate=False):
@@ -246,21 +253,21 @@ class SpecPredicate(Predicate):
self.description = description
_ops = {
- '<': operator.lt,
- '>': operator.gt,
- '==': operator.eq,
- '!=': operator.ne,
- '<=': operator.le,
- '>=': operator.ge,
- 'in': operator.contains,
- 'between': lambda val, pair: val >= pair[0] and val <= pair[1],
+ "<": operator.lt,
+ ">": operator.gt,
+ "==": operator.eq,
+ "!=": operator.ne,
+ "<=": operator.le,
+ ">=": operator.ge,
+ "in": operator.contains,
+ "between": lambda val, pair: val >= pair[0] and val <= pair[1],
}
def __call__(self, config):
engine = config.db
if "+" in self.db:
- dialect, driver = self.db.split('+')
+ dialect, driver = self.db.split("+")
else:
dialect, driver = self.db, None
@@ -273,8 +280,9 @@ class SpecPredicate(Predicate):
assert driver is None, "DBAPI version specs not supported yet"
version = _server_version(engine)
- oper = hasattr(self.op, '__call__') and self.op \
- or self._ops[self.op]
+ oper = (
+ hasattr(self.op, "__call__") and self.op or self._ops[self.op]
+ )
return oper(version, self.spec)
else:
return True
@@ -289,17 +297,9 @@ class SpecPredicate(Predicate):
return "%s" % self.db
else:
if negate:
- return "not %s %s %s" % (
- self.db,
- self.op,
- self.spec
- )
+ return "not %s %s %s" % (self.db, self.op, self.spec)
else:
- return "%s %s %s" % (
- self.db,
- self.op,
- self.spec
- )
+ return "%s %s %s" % (self.db, self.op, self.spec)
class LambdaPredicate(Predicate):
@@ -356,8 +356,9 @@ class OrPredicate(Predicate):
conjunction = " and "
else:
conjunction = " or "
- return conjunction.join(p._as_string(config, negate=negate)
- for p in self.predicates)
+ return conjunction.join(
+ p._as_string(config, negate=negate) for p in self.predicates
+ )
def _negation_str(self, config):
if self.description is not None:
@@ -387,7 +388,7 @@ def _server_version(engine):
# force metadata to be retrieved
conn = engine.connect()
- version = getattr(engine.dialect, 'server_version_info', None)
+ version = getattr(engine.dialect, "server_version_info", None)
if version is None:
version = ()
conn.close()
@@ -395,9 +396,7 @@ def _server_version(engine):
def db_spec(*dbs):
- return OrPredicate(
- [Predicate.as_predicate(db) for db in dbs]
- )
+ return OrPredicate([Predicate.as_predicate(db) for db in dbs])
def open():
@@ -422,11 +421,7 @@ def fails_on(db, reason=None):
def fails_on_everything_except(*dbs):
- return succeeds_if(
- OrPredicate([
- Predicate.as_predicate(db) for db in dbs
- ])
- )
+ return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
def skip(db, reason=None):
@@ -435,8 +430,9 @@ def skip(db, reason=None):
def only_on(dbs, reason=None):
return only_if(
- OrPredicate([Predicate.as_predicate(db, reason)
- for db in util.to_list(dbs)])
+ OrPredicate(
+ [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
+ )
)
@@ -446,7 +442,6 @@ def exclude(db, op, spec, reason=None):
def against(config, *queries):
assert queries, "no queries sent!"
- return OrPredicate([
- Predicate.as_predicate(query)
- for query in queries
- ])(config)
+ return OrPredicate([Predicate.as_predicate(query) for query in queries])(
+ config
+ )
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index dd0fa5a48..98184cdd4 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -54,19 +54,19 @@ class TestBase(object):
class TablesTest(TestBase):
# 'once', None
- run_setup_bind = 'once'
+ run_setup_bind = "once"
# 'once', 'each', None
- run_define_tables = 'once'
+ run_define_tables = "once"
# 'once', 'each', None
- run_create_tables = 'once'
+ run_create_tables = "once"
# 'once', 'each', None
- run_inserts = 'each'
+ run_inserts = "each"
# 'each', None
- run_deletes = 'each'
+ run_deletes = "each"
# 'once', None
run_dispose_bind = None
@@ -86,10 +86,10 @@ class TablesTest(TestBase):
@classmethod
def _init_class(cls):
- if cls.run_define_tables == 'each':
- if cls.run_create_tables == 'once':
- cls.run_create_tables = 'each'
- assert cls.run_inserts in ('each', None)
+ if cls.run_define_tables == "each":
+ if cls.run_create_tables == "once":
+ cls.run_create_tables = "each"
+ assert cls.run_inserts in ("each", None)
cls.other = adict()
cls.tables = adict()
@@ -100,40 +100,40 @@ class TablesTest(TestBase):
@classmethod
def _setup_once_inserts(cls):
- if cls.run_inserts == 'once':
+ if cls.run_inserts == "once":
cls._load_fixtures()
cls.insert_data()
@classmethod
def _setup_once_tables(cls):
- if cls.run_define_tables == 'once':
+ if cls.run_define_tables == "once":
cls.define_tables(cls.metadata)
- if cls.run_create_tables == 'once':
+ if cls.run_create_tables == "once":
cls.metadata.create_all(cls.bind)
cls.tables.update(cls.metadata.tables)
def _setup_each_tables(self):
- if self.run_define_tables == 'each':
+ if self.run_define_tables == "each":
self.tables.clear()
- if self.run_create_tables == 'each':
+ if self.run_create_tables == "each":
drop_all_tables(self.metadata, self.bind)
self.metadata.clear()
self.define_tables(self.metadata)
- if self.run_create_tables == 'each':
+ if self.run_create_tables == "each":
self.metadata.create_all(self.bind)
self.tables.update(self.metadata.tables)
- elif self.run_create_tables == 'each':
+ elif self.run_create_tables == "each":
drop_all_tables(self.metadata, self.bind)
self.metadata.create_all(self.bind)
def _setup_each_inserts(self):
- if self.run_inserts == 'each':
+ if self.run_inserts == "each":
self._load_fixtures()
self.insert_data()
def _teardown_each_tables(self):
# no need to run deletes if tables are recreated on setup
- if self.run_define_tables != 'each' and self.run_deletes == 'each':
+ if self.run_define_tables != "each" and self.run_deletes == "each":
with self.bind.connect() as conn:
for table in reversed(self.metadata.sorted_tables):
try:
@@ -141,7 +141,8 @@ class TablesTest(TestBase):
except sa.exc.DBAPIError as ex:
util.print_(
("Error emptying table %s: %r" % (table, ex)),
- file=sys.stderr)
+ file=sys.stderr,
+ )
def setup(self):
self._setup_each_tables()
@@ -155,7 +156,7 @@ class TablesTest(TestBase):
if cls.run_create_tables:
drop_all_tables(cls.metadata, cls.bind)
- if cls.run_dispose_bind == 'once':
+ if cls.run_dispose_bind == "once":
cls.dispose_bind(cls.bind)
cls.metadata.bind = None
@@ -173,9 +174,9 @@ class TablesTest(TestBase):
@classmethod
def dispose_bind(cls, bind):
- if hasattr(bind, 'dispose'):
+ if hasattr(bind, "dispose"):
bind.dispose()
- elif hasattr(bind, 'close'):
+ elif hasattr(bind, "close"):
bind.close()
@classmethod
@@ -212,8 +213,12 @@ class TablesTest(TestBase):
continue
cls.bind.execute(
table.insert(),
- [dict(zip(headers[table], column_values))
- for column_values in rows[table]])
+ [
+ dict(zip(headers[table], column_values))
+ for column_values in rows[table]
+ ],
+ )
+
from sqlalchemy import event
@@ -236,7 +241,6 @@ class RemovesEvents(object):
class _ORMTest(object):
-
@classmethod
def teardown_class(cls):
sa.orm.session.Session.close_all()
@@ -249,10 +253,10 @@ class ORMTest(_ORMTest, TestBase):
class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
# 'once', 'each', None
- run_setup_classes = 'once'
+ run_setup_classes = "once"
# 'once', 'each', None
- run_setup_mappers = 'each'
+ run_setup_mappers = "each"
classes = None
@@ -292,20 +296,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
@classmethod
def _setup_once_classes(cls):
- if cls.run_setup_classes == 'once':
+ if cls.run_setup_classes == "once":
cls._with_register_classes(cls.setup_classes)
@classmethod
def _setup_once_mappers(cls):
- if cls.run_setup_mappers == 'once':
+ if cls.run_setup_mappers == "once":
cls._with_register_classes(cls.setup_mappers)
def _setup_each_mappers(self):
- if self.run_setup_mappers == 'each':
+ if self.run_setup_mappers == "each":
self._with_register_classes(self.setup_mappers)
def _setup_each_classes(self):
- if self.run_setup_classes == 'each':
+ if self.run_setup_classes == "each":
self._with_register_classes(self.setup_classes)
@classmethod
@@ -339,11 +343,11 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
# some tests create mappers in the test bodies
# and will define setup_mappers as None -
# clear mappers in any case
- if self.run_setup_mappers != 'once':
+ if self.run_setup_mappers != "once":
sa.orm.clear_mappers()
def _teardown_each_classes(self):
- if self.run_setup_classes != 'once':
+ if self.run_setup_classes != "once":
self.classes.clear()
@classmethod
@@ -356,8 +360,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
class DeclarativeMappedTest(MappedTest):
- run_setup_classes = 'once'
- run_setup_mappers = 'once'
+ run_setup_classes = "once"
+ run_setup_mappers = "once"
@classmethod
def _setup_once_tables(cls):
@@ -370,15 +374,16 @@ class DeclarativeMappedTest(MappedTest):
class FindFixtureDeclarative(DeclarativeMeta):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
- return DeclarativeMeta.__init__(
- cls, classname, bases, dict_)
+ return DeclarativeMeta.__init__(cls, classname, bases, dict_)
class DeclarativeBasic(object):
__table_cls__ = schema.Table
- _DeclBase = declarative_base(metadata=cls.metadata,
- metaclass=FindFixtureDeclarative,
- cls=DeclarativeBasic)
+ _DeclBase = declarative_base(
+ metadata=cls.metadata,
+ metaclass=FindFixtureDeclarative,
+ cls=DeclarativeBasic,
+ )
cls.DeclarativeBasic = _DeclBase
fn()
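
TablesTest, whose run_* flags are normalized to double-quoted strings above, drives per-class schema setup. A hedged layout sketch; the table and rows are hypothetical, a configured test database is assumed, and cls.bind.execute() mirrors the _load_fixtures() call visible earlier in this hunk:

    import sqlalchemy as sa
    from sqlalchemy.testing import fixtures

    class UserFixtureTest(fixtures.TablesTest):
        run_inserts = "once"   # load fixture rows a single time per class
        run_deletes = None     # skip per-test DELETEs

        @classmethod
        def define_tables(cls, metadata):
            sa.Table(
                "users",
                metadata,
                sa.Column("id", sa.Integer, primary_key=True),
                sa.Column("name", sa.String(30)),
            )

        @classmethod
        def insert_data(cls):
            cls.bind.execute(
                cls.tables.users.insert(), [{"id": 1, "name": "ed"}]
            )
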
diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py
index ea0a8da82..dc530af5e 100644
--- a/lib/sqlalchemy/testing/mock.py
+++ b/lib/sqlalchemy/testing/mock.py
@@ -18,4 +18,5 @@ else:
except ImportError:
raise ImportError(
"SQLAlchemy's test suite requires the "
- "'mock' library as of 0.8.2.")
+ "'mock' library as of 0.8.2."
+ )
diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py
index 087fc1fe6..e84cbde44 100644
--- a/lib/sqlalchemy/testing/pickleable.py
+++ b/lib/sqlalchemy/testing/pickleable.py
@@ -46,29 +46,28 @@ class Parent(fixtures.ComparableEntity):
class Screen(object):
-
def __init__(self, obj, parent=None):
self.obj = obj
self.parent = parent
class Foo(object):
-
def __init__(self, moredata):
- self.data = 'im data'
- self.stuff = 'im stuff'
+ self.data = "im data"
+ self.stuff = "im stuff"
self.moredata = moredata
__hash__ = object.__hash__
def __eq__(self, other):
- return other.data == self.data and \
- other.stuff == self.stuff and \
- other.moredata == self.moredata
+ return (
+ other.data == self.data
+ and other.stuff == self.stuff
+ and other.moredata == self.moredata
+ )
class Bar(object):
-
def __init__(self, x, y):
self.x = x
self.y = y
@@ -76,35 +75,36 @@ class Bar(object):
__hash__ = object.__hash__
def __eq__(self, other):
- return other.__class__ is self.__class__ and \
- other.x == self.x and \
- other.y == self.y
+ return (
+ other.__class__ is self.__class__
+ and other.x == self.x
+ and other.y == self.y
+ )
def __str__(self):
return "Bar(%d, %d)" % (self.x, self.y)
class OldSchool:
-
def __init__(self, x, y):
self.x = x
self.y = y
def __eq__(self, other):
- return other.__class__ is self.__class__ and \
- other.x == self.x and \
- other.y == self.y
+ return (
+ other.__class__ is self.__class__
+ and other.x == self.x
+ and other.y == self.y
+ )
class OldSchoolWithoutCompare:
-
def __init__(self, x, y):
self.x = x
self.y = y
class BarWithoutCompare(object):
-
def __init__(self, x, y):
self.x = x
self.y = y
@@ -114,7 +114,6 @@ class BarWithoutCompare(object):
class NotComparable(object):
-
def __init__(self, data):
self.data = data
@@ -129,7 +128,6 @@ class NotComparable(object):
class BrokenComparable(object):
-
def __init__(self, data):
self.data = data
diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py
index 497fcb7e5..bb52c125c 100644
--- a/lib/sqlalchemy/testing/plugin/bootstrap.py
+++ b/lib/sqlalchemy/testing/plugin/bootstrap.py
@@ -20,20 +20,23 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0.
import os
import sys
-bootstrap_file = locals()['bootstrap_file']
-to_bootstrap = locals()['to_bootstrap']
+bootstrap_file = locals()["bootstrap_file"]
+to_bootstrap = locals()["to_bootstrap"]
def load_file_as_module(name):
path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
if sys.version_info >= (3, 3):
from importlib import machinery
+
mod = machinery.SourceFileLoader(name, path).load_module()
else:
import imp
+
mod = imp.load_source(name, path)
return mod
+
if to_bootstrap == "pytest":
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py
index 20ea61d89..0c28a5213 100644
--- a/lib/sqlalchemy/testing/plugin/noseplugin.py
+++ b/lib/sqlalchemy/testing/plugin/noseplugin.py
@@ -25,6 +25,7 @@ import sys
from nose.plugins import Plugin
import nose
+
fixtures = None
py3k = sys.version_info >= (3, 0)
@@ -33,7 +34,7 @@ py3k = sys.version_info >= (3, 0)
class NoseSQLAlchemy(Plugin):
enabled = True
- name = 'sqla_testing'
+ name = "sqla_testing"
score = 100
def options(self, parser, env=os.environ):
@@ -41,10 +42,14 @@ class NoseSQLAlchemy(Plugin):
opt = parser.add_option
def make_option(name, **kw):
- callback_ = kw.pop("callback", None) or kw.pop("zeroarg_callback", None)
+ callback_ = kw.pop("callback", None) or kw.pop(
+ "zeroarg_callback", None
+ )
if callback_:
+
def wrap_(option, opt_str, value, parser):
callback_(opt_str, value, parser)
+
kw["callback"] = wrap_
opt(name, **kw)
@@ -73,7 +78,7 @@ class NoseSQLAlchemy(Plugin):
def wantMethod(self, fn):
if py3k:
- if not hasattr(fn.__self__, 'cls'):
+ if not hasattr(fn.__self__, "cls"):
return False
cls = fn.__self__.cls
else:
@@ -84,24 +89,24 @@ class NoseSQLAlchemy(Plugin):
return plugin_base.want_class(cls)
def beforeTest(self, test):
- if not hasattr(test.test, 'cls'):
+ if not hasattr(test.test, "cls"):
return
plugin_base.before_test(
test,
test.test.cls.__module__,
- test.test.cls, test.test.method.__name__)
+ test.test.cls,
+ test.test.method.__name__,
+ )
def afterTest(self, test):
plugin_base.after_test(test)
def startContext(self, ctx):
- if not isinstance(ctx, type) \
- or not issubclass(ctx, fixtures.TestBase):
+ if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
return
plugin_base.start_test_class(ctx)
def stopContext(self, ctx):
- if not isinstance(ctx, type) \
- or not issubclass(ctx, fixtures.TestBase):
+ if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
return
plugin_base.stop_test_class(ctx)
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 0ffcae093..5d6bf2975 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -46,58 +46,130 @@ options = None
def setup_options(make_option):
- make_option("--log-info", action="callback", type="string", callback=_log,
- help="turn on info logging for <LOG> (multiple OK)")
- make_option("--log-debug", action="callback",
- type="string", callback=_log,
- help="turn on debug logging for <LOG> (multiple OK)")
- make_option("--db", action="append", type="string", dest="db",
- help="Use prefab database uri. Multiple OK, "
- "first one is run by default.")
- make_option('--dbs', action='callback', zeroarg_callback=_list_dbs,
- help="List available prefab dbs")
- make_option("--dburi", action="append", type="string", dest="dburi",
- help="Database uri. Multiple OK, "
- "first one is run by default.")
- make_option("--dropfirst", action="store_true", dest="dropfirst",
- help="Drop all tables in the target database first")
- make_option("--backend-only", action="store_true", dest="backend_only",
- help="Run only tests marked with __backend__")
- make_option("--nomemory", action="store_true", dest="nomemory",
- help="Don't run memory profiling tests")
- make_option("--postgresql-templatedb", type="string",
- help="name of template database to use for PostgreSQL "
- "CREATE DATABASE (defaults to current database)")
- make_option("--low-connections", action="store_true",
- dest="low_connections",
- help="Use a low number of distinct connections - "
- "i.e. for Oracle TNS")
- make_option("--write-idents", type="string", dest="write_idents",
- help="write out generated follower idents to <file>, "
- "when -n<num> is used")
- make_option("--reversetop", action="store_true",
- dest="reversetop", default=False,
- help="Use a random-ordering set implementation in the ORM "
- "(helps reveal dependency issues)")
- make_option("--requirements", action="callback", type="string",
- callback=_requirements_opt,
- help="requirements class for testing, overrides setup.cfg")
- make_option("--with-cdecimal", action="store_true",
- dest="cdecimal", default=False,
- help="Monkeypatch the cdecimal library into Python 'decimal' "
- "for all tests")
- make_option("--include-tag", action="callback", callback=_include_tag,
- type="string",
- help="Include tests with tag <tag>")
- make_option("--exclude-tag", action="callback", callback=_exclude_tag,
- type="string",
- help="Exclude tests with tag <tag>")
- make_option("--write-profiles", action="store_true",
- dest="write_profiles", default=False,
- help="Write/update failing profiling data.")
- make_option("--force-write-profiles", action="store_true",
- dest="force_write_profiles", default=False,
- help="Unconditionally write/update profiling data.")
+ make_option(
+ "--log-info",
+ action="callback",
+ type="string",
+ callback=_log,
+ help="turn on info logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--log-debug",
+ action="callback",
+ type="string",
+ callback=_log,
+ help="turn on debug logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--db",
+ action="append",
+ type="string",
+ dest="db",
+ help="Use prefab database uri. Multiple OK, "
+ "first one is run by default.",
+ )
+ make_option(
+ "--dbs",
+ action="callback",
+ zeroarg_callback=_list_dbs,
+ help="List available prefab dbs",
+ )
+ make_option(
+ "--dburi",
+ action="append",
+ type="string",
+ dest="dburi",
+ help="Database uri. Multiple OK, " "first one is run by default.",
+ )
+ make_option(
+ "--dropfirst",
+ action="store_true",
+ dest="dropfirst",
+ help="Drop all tables in the target database first",
+ )
+ make_option(
+ "--backend-only",
+ action="store_true",
+ dest="backend_only",
+ help="Run only tests marked with __backend__",
+ )
+ make_option(
+ "--nomemory",
+ action="store_true",
+ dest="nomemory",
+ help="Don't run memory profiling tests",
+ )
+ make_option(
+ "--postgresql-templatedb",
+ type="string",
+ help="name of template database to use for PostgreSQL "
+ "CREATE DATABASE (defaults to current database)",
+ )
+ make_option(
+ "--low-connections",
+ action="store_true",
+ dest="low_connections",
+ help="Use a low number of distinct connections - "
+ "i.e. for Oracle TNS",
+ )
+ make_option(
+ "--write-idents",
+ type="string",
+ dest="write_idents",
+ help="write out generated follower idents to <file>, "
+ "when -n<num> is used",
+ )
+ make_option(
+ "--reversetop",
+ action="store_true",
+ dest="reversetop",
+ default=False,
+ help="Use a random-ordering set implementation in the ORM "
+ "(helps reveal dependency issues)",
+ )
+ make_option(
+ "--requirements",
+ action="callback",
+ type="string",
+ callback=_requirements_opt,
+ help="requirements class for testing, overrides setup.cfg",
+ )
+ make_option(
+ "--with-cdecimal",
+ action="store_true",
+ dest="cdecimal",
+ default=False,
+ help="Monkeypatch the cdecimal library into Python 'decimal' "
+ "for all tests",
+ )
+ make_option(
+ "--include-tag",
+ action="callback",
+ callback=_include_tag,
+ type="string",
+ help="Include tests with tag <tag>",
+ )
+ make_option(
+ "--exclude-tag",
+ action="callback",
+ callback=_exclude_tag,
+ type="string",
+ help="Exclude tests with tag <tag>",
+ )
+ make_option(
+ "--write-profiles",
+ action="store_true",
+ dest="write_profiles",
+ default=False,
+ help="Write/update failing profiling data.",
+ )
+ make_option(
+ "--force-write-profiles",
+ action="store_true",
+ dest="force_write_profiles",
+ default=False,
+ help="Unconditionally write/update profiling data.",
+ )
def configure_follower(follower_ident):
@@ -108,6 +180,7 @@ def configure_follower(follower_ident):
"""
from sqlalchemy.testing import provision
+
provision.FOLLOWER_IDENT = follower_ident
@@ -121,9 +194,9 @@ def memoize_important_follower_config(dict_):
callables, so we have to just copy all of that over.
"""
- dict_['memoized_config'] = {
- 'include_tags': include_tags,
- 'exclude_tags': exclude_tags
+ dict_["memoized_config"] = {
+ "include_tags": include_tags,
+ "exclude_tags": exclude_tags,
}
@@ -134,14 +207,14 @@ def restore_important_follower_config(dict_):
"""
global include_tags, exclude_tags
- include_tags.update(dict_['memoized_config']['include_tags'])
- exclude_tags.update(dict_['memoized_config']['exclude_tags'])
+ include_tags.update(dict_["memoized_config"]["include_tags"])
+ exclude_tags.update(dict_["memoized_config"]["exclude_tags"])
def read_config():
global file_config
file_config = configparser.ConfigParser()
- file_config.read(['setup.cfg', 'test.cfg'])
+ file_config.read(["setup.cfg", "test.cfg"])
def pre_begin(opt):
@@ -155,6 +228,7 @@ def pre_begin(opt):
def set_coverage_flag(value):
options.has_coverage = value
+
_skip_test_exception = None
@@ -171,34 +245,33 @@ def post_begin():
# late imports, has to happen after config as well
# as nose plugins like coverage
- global util, fixtures, engines, exclusions, \
- assertions, warnings, profiling,\
- config, testing
- from sqlalchemy import testing # noqa
+ global util, fixtures, engines, exclusions, assertions, warnings, profiling, config, testing
+ from sqlalchemy import testing # noqa
from sqlalchemy.testing import fixtures, engines, exclusions # noqa
- from sqlalchemy.testing import assertions, warnings, profiling # noqa
+ from sqlalchemy.testing import assertions, warnings, profiling # noqa
from sqlalchemy.testing import config # noqa
from sqlalchemy import util # noqa
- warnings.setup_filters()
+ warnings.setup_filters()
def _log(opt_str, value, parser):
global logging
if not logging:
import logging
+
logging.basicConfig()
- if opt_str.endswith('-info'):
+ if opt_str.endswith("-info"):
logging.getLogger(value).setLevel(logging.INFO)
- elif opt_str.endswith('-debug'):
+ elif opt_str.endswith("-debug"):
logging.getLogger(value).setLevel(logging.DEBUG)
def _list_dbs(*args):
print("Available --db options (use --dburi to override)")
- for macro in sorted(file_config.options('db')):
- print("%20s\t%s" % (macro, file_config.get('db', macro)))
+ for macro in sorted(file_config.options("db")):
+ print("%20s\t%s" % (macro, file_config.get("db", macro)))
sys.exit(0)
@@ -207,11 +280,12 @@ def _requirements_opt(opt_str, value, parser):
def _exclude_tag(opt_str, value, parser):
- exclude_tags.add(value.replace('-', '_'))
+ exclude_tags.add(value.replace("-", "_"))
def _include_tag(opt_str, value, parser):
- include_tags.add(value.replace('-', '_'))
+ include_tags.add(value.replace("-", "_"))
+
pre_configure = []
post_configure = []
@@ -243,7 +317,8 @@ def _set_nomemory(opt, file_config):
def _monkeypatch_cdecimal(options, file_config):
if options.cdecimal:
import cdecimal
- sys.modules['decimal'] = cdecimal
+
+ sys.modules["decimal"] = cdecimal
@post
@@ -266,27 +341,28 @@ def _engine_uri(options, file_config):
if options.db:
for db_token in options.db:
- for db in re.split(r'[,\s]+', db_token):
- if db not in file_config.options('db'):
+ for db in re.split(r"[,\s]+", db_token):
+ if db not in file_config.options("db"):
raise RuntimeError(
"Unknown URI specifier '%s'. "
- "Specify --dbs for known uris."
- % db)
+ "Specify --dbs for known uris." % db
+ )
else:
- db_urls.append(file_config.get('db', db))
+ db_urls.append(file_config.get("db", db))
if not db_urls:
- db_urls.append(file_config.get('db', 'default'))
+ db_urls.append(file_config.get("db", "default"))
config._current = None
for db_url in db_urls:
- if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
+ if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
with open(options.write_idents, "a") as file_:
file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")
cfg = provision.setup_config(
- db_url, options, file_config, provision.FOLLOWER_IDENT)
+ db_url, options, file_config, provision.FOLLOWER_IDENT
+ )
if not config._current:
cfg.set_as_current(cfg, testing)
@@ -295,7 +371,7 @@ def _engine_uri(options, file_config):
@post
def _requirements(options, file_config):
- requirement_cls = file_config.get('sqla_testing', "requirement_cls")
+ requirement_cls = file_config.get("sqla_testing", "requirement_cls")
_setup_requirements(requirement_cls)
@@ -334,22 +410,28 @@ def _prep_testing_database(options, file_config):
pass
else:
for vname in view_names:
- e.execute(schema._DropView(
- schema.Table(vname, schema.MetaData())
- ))
+ e.execute(
+ schema._DropView(
+ schema.Table(vname, schema.MetaData())
+ )
+ )
if config.requirements.schemas.enabled_for_config(cfg):
try:
- view_names = inspector.get_view_names(
- schema="test_schema")
+ view_names = inspector.get_view_names(schema="test_schema")
except NotImplementedError:
pass
else:
for vname in view_names:
- e.execute(schema._DropView(
- schema.Table(vname, schema.MetaData(),
- schema="test_schema")
- ))
+ e.execute(
+ schema._DropView(
+ schema.Table(
+ vname,
+ schema.MetaData(),
+ schema="test_schema",
+ )
+ )
+ )
util.drop_all_tables(e, inspector)
@@ -358,23 +440,29 @@ def _prep_testing_database(options, file_config):
if against(cfg, "postgresql"):
from sqlalchemy.dialects import postgresql
+
for enum in inspector.get_enums("*"):
- e.execute(postgresql.DropEnumType(
- postgresql.ENUM(
- name=enum['name'],
- schema=enum['schema'])))
+ e.execute(
+ postgresql.DropEnumType(
+ postgresql.ENUM(
+ name=enum["name"], schema=enum["schema"]
+ )
+ )
+ )
@post
def _reverse_topological(options, file_config):
if options.reversetop:
from sqlalchemy.orm.util import randomize_unitofwork
+
randomize_unitofwork()
@post
def _post_setup_options(opt, file_config):
from sqlalchemy.testing import config
+
config.options = options
config.file_config = file_config
@@ -382,17 +470,20 @@ def _post_setup_options(opt, file_config):
@post
def _setup_profiling(options, file_config):
from sqlalchemy.testing import profiling
+
profiling._profile_stats = profiling.ProfileStatsFile(
- file_config.get('sqla_testing', 'profile_file'))
+ file_config.get("sqla_testing", "profile_file")
+ )
def want_class(cls):
if not issubclass(cls, fixtures.TestBase):
return False
- elif cls.__name__.startswith('_'):
+ elif cls.__name__.startswith("_"):
return False
- elif config.options.backend_only and not getattr(cls, '__backend__',
- False):
+ elif config.options.backend_only and not getattr(
+ cls, "__backend__", False
+ ):
return False
else:
return True
@@ -405,25 +496,28 @@ def want_method(cls, fn):
return False
elif include_tags:
return (
- hasattr(cls, '__tags__') and
- exclusions.tags(cls.__tags__).include_test(
- include_tags, exclude_tags)
+ hasattr(cls, "__tags__")
+ and exclusions.tags(cls.__tags__).include_test(
+ include_tags, exclude_tags
+ )
) or (
- hasattr(fn, '_sa_exclusion_extend') and
- fn._sa_exclusion_extend.include_test(
- include_tags, exclude_tags)
+ hasattr(fn, "_sa_exclusion_extend")
+ and fn._sa_exclusion_extend.include_test(
+ include_tags, exclude_tags
+ )
)
- elif exclude_tags and hasattr(cls, '__tags__'):
+ elif exclude_tags and hasattr(cls, "__tags__"):
return exclusions.tags(cls.__tags__).include_test(
- include_tags, exclude_tags)
- elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'):
+ include_tags, exclude_tags
+ )
+ elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
else:
return True
def generate_sub_tests(cls, module):
- if getattr(cls, '__backend__', False):
+ if getattr(cls, "__backend__", False):
for cfg in _possible_configs_for_cls(cls):
orig_name = cls.__name__
@@ -431,16 +525,13 @@ def generate_sub_tests(cls, module):
# pytest junit plugin, which is tripped up by the brackets
# and periods, so sanitize
- alpha_name = re.sub('[_\[\]\.]+', '_', cfg.name)
- alpha_name = re.sub('_+$', '', alpha_name)
+ alpha_name = re.sub("[_\[\]\.]+", "_", cfg.name)
+ alpha_name = re.sub("_+$", "", alpha_name)
name = "%s_%s" % (cls.__name__, alpha_name)
subcls = type(
name,
- (cls, ),
- {
- "_sa_orig_cls_name": orig_name,
- "__only_on_config__": cfg
- }
+ (cls,),
+ {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
)
setattr(module, name, subcls)
yield subcls
@@ -454,8 +545,8 @@ def start_test_class(cls):
def stop_test_class(cls):
- #from sqlalchemy import inspect
- #assert not inspect(testing.db).get_table_names()
+ # from sqlalchemy import inspect
+ # assert not inspect(testing.db).get_table_names()
engines.testing_reaper._stop_test_ctx()
try:
if not options.low_connections:
@@ -475,7 +566,7 @@ def final_process_cleanup():
def _setup_engine(cls):
- if getattr(cls, '__engine_options__', None):
+ if getattr(cls, "__engine_options__", None):
eng = engines.testing_engine(options=cls.__engine_options__)
config._current.push_engine(eng, testing)
@@ -485,7 +576,7 @@ def before_test(test, test_module_name, test_class, test_name):
# like a nose id, e.g.:
# "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
- name = getattr(test_class, '_sa_orig_cls_name', test_class.__name__)
+ name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
id_ = "%s.%s.%s" % (test_module_name, name, test_name)
@@ -505,16 +596,16 @@ def _possible_configs_for_cls(cls, reasons=None):
if spec(config_obj):
all_configs.remove(config_obj)
- if getattr(cls, '__only_on__', None):
+ if getattr(cls, "__only_on__", None):
spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
for config_obj in list(all_configs):
if not spec(config_obj):
all_configs.remove(config_obj)
- if getattr(cls, '__only_on_config__', None):
+ if getattr(cls, "__only_on_config__", None):
all_configs.intersection_update([cls.__only_on_config__])
- if hasattr(cls, '__requires__'):
+ if hasattr(cls, "__requires__"):
requirements = config.requirements
for config_obj in list(all_configs):
for requirement in cls.__requires__:
@@ -527,7 +618,7 @@ def _possible_configs_for_cls(cls, reasons=None):
reasons.extend(skip_reasons)
break
- if hasattr(cls, '__prefer_requires__'):
+ if hasattr(cls, "__prefer_requires__"):
non_preferred = set()
requirements = config.requirements
for config_obj in list(all_configs):
@@ -546,30 +637,32 @@ def _do_skips(cls):
reasons = []
all_configs = _possible_configs_for_cls(cls, reasons)
- if getattr(cls, '__skip_if__', False):
- for c in getattr(cls, '__skip_if__'):
+ if getattr(cls, "__skip_if__", False):
+ for c in getattr(cls, "__skip_if__"):
if c():
- config.skip_test("'%s' skipped by %s" % (
- cls.__name__, c.__name__)
+ config.skip_test(
+ "'%s' skipped by %s" % (cls.__name__, c.__name__)
)
if not all_configs:
msg = "'%s' unsupported on any DB implementation %s%s" % (
cls.__name__,
", ".join(
- "'%s(%s)+%s'" % (
+ "'%s(%s)+%s'"
+ % (
config_obj.db.name,
".".join(
- str(dig) for dig in
- exclusions._server_version(config_obj.db)),
- config_obj.db.driver
+ str(dig)
+ for dig in exclusions._server_version(config_obj.db)
+ ),
+ config_obj.db.driver,
)
- for config_obj in config.Config.all_configs()
+ for config_obj in config.Config.all_configs()
),
- ", ".join(reasons)
+ ", ".join(reasons),
)
config.skip_test(msg)
- elif hasattr(cls, '__prefer_backends__'):
+ elif hasattr(cls, "__prefer_backends__"):
non_preferred = set()
spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
for config_obj in all_configs:
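
Aside: the `generate_sub_tests` hunk above is the whole of the per-backend sub-test mechanism — sanitize the config name (the pytest junit plugin trips on brackets and periods), then stamp out one subclass per configuration. A minimal standalone sketch of the same pattern; `make_subclasses` and the sample config name are invented for illustration:

```python
import re


def make_subclasses(cls, config_names):
    for cfg_name in config_names:
        # same sanitization as above: collapse runs of "_", "[", "]", "."
        # into a single "_", then strip any trailing underscores
        alpha = re.sub(r"[_\[\]\.]+", "_", cfg_name)
        alpha = re.sub(r"_+$", "", alpha)
        name = "%s_%s" % (cls.__name__, alpha)
        yield type(name, (cls,), {"_sa_orig_cls_name": cls.__name__})


class SomeTest(object):
    pass


for sub in make_subclasses(SomeTest, ["mysql_[5.7]"]):
    print(sub.__name__)  # SomeTest_mysql_5_7
```
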
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index da682ea00..fd0a48462 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -13,6 +13,7 @@ import os
try:
import xdist # noqa
+
has_xdist = True
except ImportError:
has_xdist = False
@@ -24,30 +25,42 @@ def pytest_addoption(parser):
def make_option(name, **kw):
callback_ = kw.pop("callback", None)
if callback_:
+
class CallableAction(argparse.Action):
- def __call__(self, parser, namespace,
- values, option_string=None):
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
callback_(option_string, values, parser)
+
kw["action"] = CallableAction
zeroarg_callback = kw.pop("zeroarg_callback", None)
if zeroarg_callback:
+
class CallableAction(argparse.Action):
- def __init__(self, option_strings,
- dest, default=False,
- required=False, help=None):
- super(CallableAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- nargs=0,
- const=True,
- default=default,
- required=required,
- help=help)
-
- def __call__(self, parser, namespace,
- values, option_string=None):
+ def __init__(
+ self,
+ option_strings,
+ dest,
+ default=False,
+ required=False,
+ help=None,
+ ):
+ super(CallableAction, self).__init__(
+ option_strings=option_strings,
+ dest=dest,
+ nargs=0,
+ const=True,
+ default=default,
+ required=required,
+ help=help,
+ )
+
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
zeroarg_callback(option_string, values, parser)
+
kw["action"] = CallableAction
group.addoption(name, **kw)
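
The two `CallableAction` shims above exist because optparse-style `callback` / `zeroarg_callback` options have no direct argparse equivalent. Stripped of the plugin machinery, the zero-arg variant is just a `nargs=0` Action whose `__call__` fires the callback. A self-contained sketch — the `--log-info` flag and `on_flag` callback are invented here:

```python
import argparse


def on_flag(option_string, values, parser):
    print("saw %s" % option_string)


class ZeroArgCallback(argparse.Action):
    # consume no arguments; fire the callback whenever the flag appears
    def __init__(
        self, option_strings, dest, default=False, required=False, help=None
    ):
        super(ZeroArgCallback, self).__init__(
            option_strings=option_strings,
            dest=dest,
            nargs=0,
            const=True,
            default=default,
            required=required,
            help=help,
        )

    def __call__(self, parser, namespace, values, option_string=None):
        on_flag(option_string, values, parser)


parser = argparse.ArgumentParser()
parser.add_argument("--log-info", action=ZeroArgCallback)
parser.parse_args(["--log-info"])  # prints: saw --log-info
```
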
@@ -59,18 +72,18 @@ def pytest_addoption(parser):
def pytest_configure(config):
if hasattr(config, "slaveinput"):
plugin_base.restore_important_follower_config(config.slaveinput)
- plugin_base.configure_follower(
- config.slaveinput["follower_ident"]
- )
+ plugin_base.configure_follower(config.slaveinput["follower_ident"])
else:
- if config.option.write_idents and \
- os.path.exists(config.option.write_idents):
+ if config.option.write_idents and os.path.exists(
+ config.option.write_idents
+ ):
os.remove(config.option.write_idents)
plugin_base.pre_begin(config.option)
- plugin_base.set_coverage_flag(bool(getattr(config.option,
- "cov_source", False)))
+ plugin_base.set_coverage_flag(
+ bool(getattr(config.option, "cov_source", False))
+ )
plugin_base.set_skip_test(pytest.skip.Exception)
@@ -94,10 +107,12 @@ if has_xdist:
node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
from sqlalchemy.testing import provision
+
provision.create_follower_db(node.slaveinput["follower_ident"])
def pytest_testnodedown(node, error):
from sqlalchemy.testing import provision
+
provision.drop_follower_db(node.slaveinput["follower_ident"])
@@ -114,19 +129,22 @@ def pytest_collection_modifyitems(session, config, items):
rebuilt_items = collections.defaultdict(list)
items[:] = [
- item for item in
- items if isinstance(item.parent, pytest.Instance)
- and not item.parent.parent.name.startswith("_")]
+ item
+ for item in items
+ if isinstance(item.parent, pytest.Instance)
+ and not item.parent.parent.name.startswith("_")
+ ]
test_classes = set(item.parent for item in items)
for test_class in test_classes:
for sub_cls in plugin_base.generate_sub_tests(
- test_class.cls, test_class.parent.module):
+ test_class.cls, test_class.parent.module
+ ):
if sub_cls is not test_class.cls:
list_ = rebuilt_items[test_class.cls]
for inst in pytest.Class(
- sub_cls.__name__,
- parent=test_class.parent.parent).collect():
+ sub_cls.__name__, parent=test_class.parent.parent
+ ).collect():
list_.extend(inst.collect())
newitems = []
@@ -139,23 +157,29 @@ def pytest_collection_modifyitems(session, config, items):
# seems like the functions attached to a test class aren't sorted already?
# is that true and why's that? (when using unittest, they're sorted)
- items[:] = sorted(newitems, key=lambda item: (
- item.parent.parent.parent.name,
- item.parent.parent.name,
- item.name
- ))
+ items[:] = sorted(
+ newitems,
+ key=lambda item: (
+ item.parent.parent.parent.name,
+ item.parent.parent.name,
+ item.name,
+ ),
+ )
def pytest_pycollect_makeitem(collector, name, obj):
if inspect.isclass(obj) and plugin_base.want_class(obj):
return pytest.Class(name, parent=collector)
- elif inspect.isfunction(obj) and \
- isinstance(collector, pytest.Instance) and \
- plugin_base.want_method(collector.cls, obj):
+ elif (
+ inspect.isfunction(obj)
+ and isinstance(collector, pytest.Instance)
+ and plugin_base.want_method(collector.cls, obj)
+ ):
return pytest.Function(name, parent=collector)
else:
return []
+
_current_class = None
@@ -180,6 +204,7 @@ def pytest_runtest_setup(item):
global _current_class
class_teardown(item.parent.parent)
_current_class = None
+
item.parent.parent.addfinalizer(finalize)
test_setup(item)
@@ -194,8 +219,9 @@ def pytest_runtest_teardown(item):
def test_setup(item):
- plugin_base.before_test(item, item.parent.module.__name__,
- item.parent.cls, item.name)
+ plugin_base.before_test(
+ item, item.parent.module.__name__, item.parent.cls, item.name
+ )
def test_teardown(item):
diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py
index fab99b186..3986985c7 100644
--- a/lib/sqlalchemy/testing/profiling.py
+++ b/lib/sqlalchemy/testing/profiling.py
@@ -42,17 +42,16 @@ class ProfileStatsFile(object):
def __init__(self, filename):
self.force_write = (
- config.options is not None and
- config.options.force_write_profiles
+ config.options is not None and config.options.force_write_profiles
)
self.write = self.force_write or (
- config.options is not None and
- config.options.write_profiles
+ config.options is not None and config.options.write_profiles
)
self.fname = os.path.abspath(filename)
self.short_fname = os.path.split(self.fname)[-1]
self.data = collections.defaultdict(
- lambda: collections.defaultdict(dict))
+ lambda: collections.defaultdict(dict)
+ )
self._read()
if self.write:
# rewrite for the case where features changed,
@@ -65,7 +64,7 @@ class ProfileStatsFile(object):
dbapi_key = config.db.name + "_" + config.db.driver
# keep it at 2.7, 3.1, 3.2, etc. for now.
- py_version = '.'.join([str(v) for v in sys.version_info[0:2]])
+ py_version = ".".join([str(v) for v in sys.version_info[0:2]])
platform_tokens = [py_version]
platform_tokens.append(dbapi_key)
@@ -87,8 +86,7 @@ class ProfileStatsFile(object):
def has_stats(self):
test_key = _current_test
return (
- test_key in self.data and
- self.platform_key in self.data[test_key]
+ test_key in self.data and self.platform_key in self.data[test_key]
)
def result(self, callcount):
@@ -96,15 +94,15 @@ class ProfileStatsFile(object):
per_fn = self.data[test_key]
per_platform = per_fn[self.platform_key]
- if 'counts' not in per_platform:
- per_platform['counts'] = counts = []
+ if "counts" not in per_platform:
+ per_platform["counts"] = counts = []
else:
- counts = per_platform['counts']
+ counts = per_platform["counts"]
- if 'current_count' not in per_platform:
- per_platform['current_count'] = current_count = 0
+ if "current_count" not in per_platform:
+ per_platform["current_count"] = current_count = 0
else:
- current_count = per_platform['current_count']
+ current_count = per_platform["current_count"]
has_count = len(counts) > current_count
@@ -114,16 +112,16 @@ class ProfileStatsFile(object):
self._write()
result = None
else:
- result = per_platform['lineno'], counts[current_count]
- per_platform['current_count'] += 1
+ result = per_platform["lineno"], counts[current_count]
+ per_platform["current_count"] += 1
return result
def replace(self, callcount):
test_key = _current_test
per_fn = self.data[test_key]
per_platform = per_fn[self.platform_key]
- counts = per_platform['counts']
- current_count = per_platform['current_count']
+ counts = per_platform["counts"]
+ current_count = per_platform["current_count"]
if current_count < len(counts):
counts[current_count - 1] = callcount
else:
@@ -164,9 +162,9 @@ class ProfileStatsFile(object):
per_fn = self.data[test_key]
per_platform = per_fn[platform_key]
c = [int(count) for count in counts.split(",")]
- per_platform['counts'] = c
- per_platform['lineno'] = lineno + 1
- per_platform['current_count'] = 0
+ per_platform["counts"] = c
+ per_platform["lineno"] = lineno + 1
+ per_platform["current_count"] = 0
profile_f.close()
def _write(self):
@@ -179,7 +177,7 @@ class ProfileStatsFile(object):
profile_f.write("\n# TEST: %s\n\n" % test_key)
for platform_key in sorted(per_fn):
per_platform = per_fn[platform_key]
- c = ",".join(str(count) for count in per_platform['counts'])
+ c = ",".join(str(count) for count in per_platform["counts"])
profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
profile_f.close()
@@ -199,7 +197,9 @@ def function_call_count(variance=0.05):
def wrap(*args, **kw):
with count_functions(variance=variance):
return fn(*args, **kw)
+
return update_wrapper(wrap, fn)
+
return decorate
@@ -213,21 +213,22 @@ def count_functions(variance=0.05):
"No profiling stats available on this "
"platform for this function. Run tests with "
"--write-profiles to add statistics to %s for "
- "this platform." % _profile_stats.short_fname)
+ "this platform." % _profile_stats.short_fname
+ )
gc_collect()
pr = cProfile.Profile()
pr.enable()
- #began = time.time()
+ # began = time.time()
yield
- #ended = time.time()
+ # ended = time.time()
pr.disable()
- #s = compat.StringIO()
+ # s = compat.StringIO()
stats = pstats.Stats(pr, stream=sys.stdout)
- #timespent = ended - began
+ # timespent = ended - began
callcount = stats.total_calls
expected = _profile_stats.result(callcount)
@@ -237,11 +238,7 @@ def count_functions(variance=0.05):
else:
line_no, expected_count = expected
- print(("Pstats calls: %d Expected %s" % (
- callcount,
- expected_count
- )
- ))
+ print(("Pstats calls: %d Expected %s" % (callcount, expected_count)))
stats.sort_stats("cumulative")
stats.print_stats()
@@ -259,7 +256,9 @@ def count_functions(variance=0.05):
"--write-profiles to "
"regenerate this callcount."
% (
- callcount, (variance * 100),
- expected_count, _profile_stats.platform_key))
-
-
+ callcount,
+ (variance * 100),
+ expected_count,
+ _profile_stats.platform_key,
+ )
+ )
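
For reference, `ProfileStatsFile._read` / `_write` above round-trip a plain-text layout: a `# TEST: <test_key>` comment header per test, then one `<test_key> <platform_key> <comma-separated counts>` line per platform. A small parser sketch for that layout (sample data invented):

```python
import collections


def read_profiles(lines):
    data = collections.defaultdict(dict)
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blanks and "# TEST: ..." headers
        test_key, platform_key, counts = line.split()
        data[test_key][platform_key] = [int(c) for c in counts.split(",")]
    return data


sample = [
    "# TEST: test.aaa_profiling.test_compiler.CompileTest.test_update",
    "test.aaa_profiling.test_compiler.CompileTest.test_update "
    "3.7_sqlite_pysqlite 78,78",
]
print(read_profiles(sample))
```
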
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index c0ca7c1cb..25028ccb3 100644
--- a/lib/sqlalchemy/testing/provision.py
+++ b/lib/sqlalchemy/testing/provision.py
@@ -8,6 +8,7 @@ import collections
import os
import time
import logging
+
log = logging.getLogger(__name__)
FOLLOWER_IDENT = None
@@ -25,6 +26,7 @@ class register(object):
def decorate(fn):
self.fns[dbname] = fn
return self
+
return decorate
def __call__(self, cfg, *arg):
@@ -38,7 +40,7 @@ class register(object):
if backend in self.fns:
return self.fns[backend](cfg, *arg)
else:
- return self.fns['*'](cfg, *arg)
+ return self.fns["*"](cfg, *arg)
def create_follower_db(follower_ident):
@@ -82,9 +84,7 @@ def _configs_for_db_operation():
for cfg in config.Config.all_configs():
url = cfg.db.url
backend = url.get_backend_name()
- host_conf = (
- backend,
- url.username, url.host, url.database)
+ host_conf = (backend, url.username, url.host, url.database)
if host_conf not in hosts:
yield cfg
@@ -128,14 +128,13 @@ def _follower_url_from_main(url, ident):
@_update_db_opts.for_db("mssql")
def _mssql_update_db_opts(db_url, db_opts):
- db_opts['legacy_schema_aliasing'] = False
-
+ db_opts["legacy_schema_aliasing"] = False
@_follower_url_from_main.for_db("sqlite")
def _sqlite_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
- if not url.database or url.database == ':memory:':
+ if not url.database or url.database == ":memory:":
return url
else:
return sa_url.make_url("sqlite:///%s.db" % ident)
@@ -151,19 +150,20 @@ def _sqlite_post_configure_engine(url, engine, follower_ident):
# as an attached
if not follower_ident:
dbapi_connection.execute(
- 'ATTACH DATABASE "test_schema.db" AS test_schema')
+ 'ATTACH DATABASE "test_schema.db" AS test_schema'
+ )
else:
dbapi_connection.execute(
'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
- % follower_ident)
+ % follower_ident
+ )
@_create_db.for_db("postgresql")
def _pg_create_db(cfg, eng, ident):
template_db = cfg.options.postgresql_templatedb
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
try:
_pg_drop_db(cfg, conn, ident)
except Exception:
@@ -175,7 +175,8 @@ def _pg_create_db(cfg, eng, ident):
while True:
try:
conn.execute(
- "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db))
+ "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)
+ )
except exc.OperationalError as err:
attempt += 1
if attempt >= 3:
@@ -184,8 +185,11 @@ def _pg_create_db(cfg, eng, ident):
log.info(
"Waiting to create %s, URI %r, "
"template DB %s is in use sleeping for .5",
- ident, eng.url, template_db)
- time.sleep(.5)
+ ident,
+ eng.url,
+ template_db,
+ )
+ time.sleep(0.5)
else:
break
@@ -203,9 +207,11 @@ def _mysql_create_db(cfg, eng, ident):
# 1271, u"Illegal mix of collations for operation 'UNION'"
conn.execute("CREATE DATABASE %s CHARACTER SET utf8mb3" % ident)
conn.execute(
- "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident)
+ "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident
+ )
conn.execute(
- "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident)
+ "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident
+ )
@_configure_follower.for_db("mysql")
@@ -221,14 +227,15 @@ def _sqlite_create_db(cfg, eng, ident):
@_drop_db.for_db("postgresql")
def _pg_drop_db(cfg, eng, ident):
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
conn.execute(
text(
"select pg_terminate_backend(pid) from pg_stat_activity "
"where usename=current_user and pid != pg_backend_pid() "
"and datname=:dname"
- ), dname=ident)
+ ),
+ dname=ident,
+ )
conn.execute("DROP DATABASE %s" % ident)
@@ -257,11 +264,12 @@ def _oracle_create_db(cfg, eng, ident):
conn.execute("create user %s identified by xe" % ident)
conn.execute("create user %s_ts1 identified by xe" % ident)
conn.execute("create user %s_ts2 identified by xe" % ident)
- conn.execute("grant dba to %s" % (ident, ))
+ conn.execute("grant dba to %s" % (ident,))
conn.execute("grant unlimited tablespace to %s" % ident)
conn.execute("grant unlimited tablespace to %s_ts1" % ident)
conn.execute("grant unlimited tablespace to %s_ts2" % ident)
+
@_configure_follower.for_db("oracle")
def _oracle_configure_follower(config, ident):
config.test_schema = "%s_ts1" % ident
@@ -320,6 +328,7 @@ def reap_dbs(idents_file):
elif backend == "mssql":
_reap_mssql_dbs(url, ident)
+
def _reap_oracle_dbs(url, idents):
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
@@ -330,8 +339,9 @@ def _reap_oracle_dbs(url, idents):
to_reap = conn.execute(
"select u.username from all_users u where username "
"like 'TEST_%' and not exists (select username "
- "from v$session where username=u.username)")
- all_names = {username.lower() for (username, ) in to_reap}
+ "from v$session where username=u.username)"
+ )
+ all_names = {username.lower() for (username,) in to_reap}
to_drop = set()
for name in all_names:
if name.endswith("_ts1") or name.endswith("_ts2"):
@@ -348,28 +358,28 @@ def _reap_oracle_dbs(url, idents):
if _ora_drop_ignore(conn, username):
dropped += 1
log.info(
- "Dropped %d out of %d stale databases detected",
- dropped, total)
-
+ "Dropped %d out of %d stale databases detected", dropped, total
+ )
@_follower_url_from_main.for_db("oracle")
def _oracle_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
url.username = ident
- url.password = 'xe'
+ url.password = "xe"
return url
@_create_db.for_db("mssql")
def _mssql_create_db(cfg, eng, ident):
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
conn.execute("create database %s" % ident)
conn.execute(
- "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident)
+ "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident
+ )
conn.execute(
- "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident)
+ "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident
+ )
conn.execute("use %s" % ident)
conn.execute("create schema test_schema")
conn.execute("create schema test_schema_2")
@@ -377,10 +387,10 @@ def _mssql_create_db(cfg, eng, ident):
@_drop_db.for_db("mssql")
def _mssql_drop_db(cfg, eng, ident):
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
_mssql_drop_ignore(conn, ident)
+
def _mssql_drop_ignore(conn, ident):
try:
# typically when this happens, we can't KILL the session anyway,
@@ -401,8 +411,7 @@ def _mssql_drop_ignore(conn, ident):
def _reap_mssql_dbs(url, idents):
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
log.info("identifiers in file: %s", ", ".join(idents))
@@ -410,8 +419,9 @@ def _reap_mssql_dbs(url, idents):
"select d.name from sys.databases as d where name "
"like 'TEST_%' and not exists (select session_id "
"from sys.dm_exec_sessions "
- "where database_id=d.database_id)")
- all_names = {dbname.lower() for (dbname, ) in to_reap}
+ "where database_id=d.database_id)"
+ )
+ all_names = {dbname.lower() for (dbname,) in to_reap}
to_drop = set()
for name in all_names:
if name in idents:
@@ -422,5 +432,5 @@ def _reap_mssql_dbs(url, idents):
if _mssql_drop_ignore(conn, dbname):
dropped += 1
log.info(
- "Dropped %d out of %d stale databases detected",
- dropped, total)
+ "Dropped %d out of %d stale databases detected", dropped, total
+ )
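
Everything in this file hangs off the small `register` dispatcher whose lookup is visible above: `self.fns[backend]` with `self.fns["*"]` as the fallback, populated through `for_db()` decorators. A condensed sketch of the pattern — keyed here by a plain backend string instead of a config's URL, and with an `init` helper that is this sketch's own assumption:

```python
class register(object):
    def __init__(self):
        self.fns = {}

    @classmethod
    def init(cls, fn):
        # install fn as the "*" fallback (illustrative constructor)
        reg = cls()
        reg.fns["*"] = fn
        return reg

    def for_db(self, dbname):
        def decorate(fn):
            self.fns[dbname] = fn
            return self

        return decorate

    def __call__(self, backend, *arg):
        fn = self.fns.get(backend, self.fns["*"])
        return fn(backend, *arg)


@register.init
def _create_db(backend, ident):
    raise NotImplementedError("no create-db support for %s" % backend)


@_create_db.for_db("sqlite")
def _sqlite_create_db(backend, ident):
    return "%s.db" % ident


print(_create_db("sqlite", "test_abc"))  # test_abc.db
```
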
diff --git a/lib/sqlalchemy/testing/replay_fixture.py b/lib/sqlalchemy/testing/replay_fixture.py
index b50f52e3d..9832b07a2 100644
--- a/lib/sqlalchemy/testing/replay_fixture.py
+++ b/lib/sqlalchemy/testing/replay_fixture.py
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
class ReplayFixtureTest(fixtures.TestBase):
-
@contextlib.contextmanager
def _dummy_ctx(self, *arg, **kw):
yield
@@ -22,8 +21,8 @@ class ReplayFixtureTest(fixtures.TestBase):
creator = config.db.pool._creator
recorder = lambda: dbapi_session.recorder(creator())
engine = create_engine(
- config.db.url, creator=recorder,
- use_native_hstore=False)
+ config.db.url, creator=recorder, use_native_hstore=False
+ )
self.metadata = MetaData(engine)
self.engine = engine
self.session = Session(engine)
@@ -37,8 +36,8 @@ class ReplayFixtureTest(fixtures.TestBase):
player = lambda: dbapi_session.player()
engine = create_engine(
- config.db.url, creator=player,
- use_native_hstore=False)
+ config.db.url, creator=player, use_native_hstore=False
+ )
self.metadata = MetaData(engine)
self.engine = engine
@@ -74,21 +73,49 @@ class ReplayableSession(object):
NoAttribute = object()
if util.py2k:
- Natives = set([getattr(types, t)
- for t in dir(types) if not t.startswith('_')]).\
- difference([getattr(types, t)
- for t in ('FunctionType', 'BuiltinFunctionType',
- 'MethodType', 'BuiltinMethodType',
- 'LambdaType', 'UnboundMethodType',)])
+ Natives = set(
+ [getattr(types, t) for t in dir(types) if not t.startswith("_")]
+ ).difference(
+ [
+ getattr(types, t)
+ for t in (
+ "FunctionType",
+ "BuiltinFunctionType",
+ "MethodType",
+ "BuiltinMethodType",
+ "LambdaType",
+ "UnboundMethodType",
+ )
+ ]
+ )
else:
- Natives = set([getattr(types, t)
- for t in dir(types) if not t.startswith('_')]).\
- union([type(t) if not isinstance(t, type)
- else t for t in __builtins__.values()]).\
- difference([getattr(types, t)
- for t in ('FunctionType', 'BuiltinFunctionType',
- 'MethodType', 'BuiltinMethodType',
- 'LambdaType', )])
+ Natives = (
+ set(
+ [
+ getattr(types, t)
+ for t in dir(types)
+ if not t.startswith("_")
+ ]
+ )
+ .union(
+ [
+ type(t) if not isinstance(t, type) else t
+ for t in __builtins__.values()
+ ]
+ )
+ .difference(
+ [
+ getattr(types, t)
+ for t in (
+ "FunctionType",
+ "BuiltinFunctionType",
+ "MethodType",
+ "BuiltinMethodType",
+ "LambdaType",
+ )
+ ]
+ )
+ )
def __init__(self):
self.buffer = deque()
@@ -105,8 +132,10 @@ class ReplayableSession(object):
self._subject = subject
def __call__(self, *args, **kw):
- subject, buffer = [object.__getattribute__(self, x)
- for x in ('_subject', '_buffer')]
+ subject, buffer = [
+ object.__getattribute__(self, x)
+ for x in ("_subject", "_buffer")
+ ]
result = subject(*args, **kw)
if type(result) not in ReplayableSession.Natives:
@@ -126,8 +155,10 @@ class ReplayableSession(object):
except AttributeError:
pass
- subject, buffer = [object.__getattribute__(self, x)
- for x in ('_subject', '_buffer')]
+ subject, buffer = [
+ object.__getattribute__(self, x)
+ for x in ("_subject", "_buffer")
+ ]
try:
result = type(subject).__getattribute__(subject, key)
except AttributeError:
@@ -146,7 +177,7 @@ class ReplayableSession(object):
self._buffer = buffer
def __call__(self, *args, **kw):
- buffer = object.__getattribute__(self, '_buffer')
+ buffer = object.__getattribute__(self, "_buffer")
result = buffer.popleft()
if result is ReplayableSession.Callable:
return self
@@ -162,7 +193,7 @@ class ReplayableSession(object):
return object.__getattribute__(self, key)
except AttributeError:
pass
- buffer = object.__getattribute__(self, '_buffer')
+ buffer = object.__getattribute__(self, "_buffer")
result = buffer.popleft()
if result is ReplayableSession.Callable:
return self
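
Underneath the attribute proxying, the record/replay scheme above is a shared deque: the recorder executes the real call and appends each non-native result; the player pops the same results back off in order and never touches a real subject. The core, reduced to a bare sketch (the real classes also intercept `__getattribute__` and filter the `Natives` set):

```python
from collections import deque


class Recorder(object):
    def __init__(self, subject, buffer):
        self._subject = subject
        self._buffer = buffer

    def __call__(self, *args, **kw):
        result = self._subject(*args, **kw)  # really execute
        self._buffer.append(result)  # remember the outcome
        return result


class Player(object):
    def __init__(self, buffer):
        self._buffer = buffer

    def __call__(self, *args, **kw):
        return self._buffer.popleft()  # replay; arguments are ignored


log = deque()
Recorder(len, log)("abcd")  # calls len() for real, buffers 4
print(Player(log)("wxyz"))  # prints 4 without calling len() again
```
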
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index 58df643f4..c96d26d32 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -26,7 +26,6 @@ class Requirements(object):
class SuiteRequirements(Requirements):
-
@property
def create_table(self):
"""target platform can emit basic CreateTable DDL."""
@@ -68,8 +67,8 @@ class SuiteRequirements(Requirements):
# somehow only_if([x, y]) isn't working here, negation/conjunctions
# getting confused.
return exclusions.only_if(
- lambda: self.on_update_cascade.enabled or
- self.deferrable_fks.enabled
+ lambda: self.on_update_cascade.enabled
+ or self.deferrable_fks.enabled
)
@property
@@ -231,22 +230,21 @@ class SuiteRequirements(Requirements):
def sane_rowcount(self):
return exclusions.skip_if(
lambda config: not config.db.dialect.supports_sane_rowcount,
- "driver doesn't support 'sane' rowcount"
+ "driver doesn't support 'sane' rowcount",
)
@property
def sane_multi_rowcount(self):
return exclusions.fails_if(
lambda config: not config.db.dialect.supports_sane_multi_rowcount,
- "driver %(driver)s %(doesnt_support)s 'sane' multi row count"
+ "driver %(driver)s %(doesnt_support)s 'sane' multi row count",
)
@property
def sane_rowcount_w_returning(self):
return exclusions.fails_if(
- lambda config:
- not config.db.dialect.supports_sane_rowcount_returning,
- "driver doesn't support 'sane' rowcount when returning is on"
+ lambda config: not config.db.dialect.supports_sane_rowcount_returning,
+ "driver doesn't support 'sane' rowcount when returning is on",
)
@property
@@ -255,9 +253,9 @@ class SuiteRequirements(Requirements):
INSERT DEFAULT VALUES or equivalent."""
return exclusions.only_if(
- lambda config: config.db.dialect.supports_empty_insert or
- config.db.dialect.supports_default_values,
- "empty inserts not supported"
+ lambda config: config.db.dialect.supports_empty_insert
+ or config.db.dialect.supports_default_values,
+ "empty inserts not supported",
)
@property
@@ -272,7 +270,7 @@ class SuiteRequirements(Requirements):
return exclusions.only_if(
lambda config: config.db.dialect.implicit_returning,
- "%(database)s %(does_support)s 'returning'"
+ "%(database)s %(does_support)s 'returning'",
)
@property
@@ -297,7 +295,7 @@ class SuiteRequirements(Requirements):
return exclusions.skip_if(
lambda config: not config.db.dialect.requires_name_normalize,
- "Backend does not require denormalized names."
+ "Backend does not require denormalized names.",
)
@property
@@ -307,7 +305,7 @@ class SuiteRequirements(Requirements):
return exclusions.skip_if(
lambda config: not config.db.dialect.supports_multivalues_insert,
- "Backend does not support multirow inserts."
+ "Backend does not support multirow inserts.",
)
@property
@@ -355,27 +353,32 @@ class SuiteRequirements(Requirements):
def server_side_cursors(self):
"""Target dialect must support server side cursors."""
- return exclusions.only_if([
- lambda config: config.db.dialect.supports_server_side_cursors
- ], "no server side cursors support")
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_server_side_cursors],
+ "no server side cursors support",
+ )
@property
def sequences(self):
"""Target database must support SEQUENCEs."""
- return exclusions.only_if([
- lambda config: config.db.dialect.supports_sequences
- ], "no sequence support")
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_sequences],
+ "no sequence support",
+ )
@property
def sequences_optional(self):
"""Target database supports sequences, but also optionally
as a means of generating new PK values."""
- return exclusions.only_if([
- lambda config: config.db.dialect.supports_sequences and
- config.db.dialect.sequences_optional
- ], "no sequence support, or sequences not optional")
+ return exclusions.only_if(
+ [
+ lambda config: config.db.dialect.supports_sequences
+ and config.db.dialect.sequences_optional
+ ],
+ "no sequence support, or sequences not optional",
+ )
@property
def reflects_pk_names(self):
@@ -841,7 +844,8 @@ class SuiteRequirements(Requirements):
"""
return exclusions.skip_if(
- lambda config: config.options.low_connections)
+ lambda config: config.options.low_connections
+ )
@property
def timing_intensive(self):
@@ -859,37 +863,37 @@ class SuiteRequirements(Requirements):
"""
return exclusions.skip_if(
lambda config: util.py3k and config.options.has_coverage,
- "Stability issues with coverage + py3k"
+ "Stability issues with coverage + py3k",
)
@property
def python2(self):
return exclusions.skip_if(
lambda: sys.version_info >= (3,),
- "Python version 2.xx is required."
+ "Python version 2.xx is required.",
)
@property
def python3(self):
return exclusions.skip_if(
- lambda: sys.version_info < (3,),
- "Python version 3.xx is required."
+ lambda: sys.version_info < (3,), "Python version 3.xx is required."
)
@property
def cpython(self):
return exclusions.only_if(
- lambda: util.cpython,
- "cPython interpreter needed"
+ lambda: util.cpython, "cPython interpreter needed"
)
@property
def non_broken_pickle(self):
from sqlalchemy.util import pickle
+
return exclusions.only_if(
- lambda: not util.pypy and pickle.__name__ == 'cPickle'
- or sys.version_info >= (3, 2),
- "Needs cPickle+cPython or newer Python 3 pickle"
+ lambda: not util.pypy
+ and pickle.__name__ == "cPickle"
+ or sys.version_info >= (3, 2),
+ "Needs cPickle+cPython or newer Python 3 pickle",
)
@property
@@ -910,7 +914,7 @@ class SuiteRequirements(Requirements):
"""
return exclusions.skip_if(
lambda config: config.options.has_coverage,
- "Issues observed when coverage is enabled"
+ "Issues observed when coverage is enabled",
)
def _has_mysql_on_windows(self, config):
@@ -931,8 +935,9 @@ class SuiteRequirements(Requirements):
def _has_sqlite(self):
from sqlalchemy import create_engine
+
try:
- create_engine('sqlite://')
+ create_engine("sqlite://")
return True
except ImportError:
return False
@@ -940,6 +945,7 @@ class SuiteRequirements(Requirements):
def _has_cextensions(self):
try:
from sqlalchemy import cresultproxy, cprocessors
+
return True
except ImportError:
return False
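
`_has_sqlite` and `_has_cextensions` above share one feature-probe shape: attempt the import (or engine construction) and convert `ImportError` into `False`. Generalized into a sketch (`_has_importable` is a name invented here):

```python
def _has_importable(*names):
    # True only if every named module imports cleanly
    try:
        for name in names:
            __import__(name)
        return True
    except ImportError:
        return False


print(_has_importable("sqlite3"))  # True on a stock CPython build
print(_has_importable("no_such_module"))  # False
```
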
diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py
index 87be0749c..6aa820fd5 100644
--- a/lib/sqlalchemy/testing/runner.py
+++ b/lib/sqlalchemy/testing/runner.py
@@ -47,4 +47,4 @@ def setup_py_test():
to nose.
"""
- nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner'])
+ nose.main(addplugins=[NoseSQLAlchemy()], argv=["runner"])
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py
index 401c8cbb7..b345a9487 100644
--- a/lib/sqlalchemy/testing/schema.py
+++ b/lib/sqlalchemy/testing/schema.py
@@ -9,7 +9,7 @@ from . import exclusions
from .. import schema, event
from . import config
-__all__ = 'Table', 'Column',
+__all__ = "Table", "Column"
table_options = {}
@@ -17,30 +17,35 @@ table_options = {}
def Table(*args, **kw):
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
- test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')}
+ test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
kw.update(table_options)
- if exclusions.against(config._current, 'mysql'):
- if 'mysql_engine' not in kw and 'mysql_type' not in kw and \
- "autoload_with" not in kw:
- if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
- kw['mysql_engine'] = 'InnoDB'
+ if exclusions.against(config._current, "mysql"):
+ if (
+ "mysql_engine" not in kw
+ and "mysql_type" not in kw
+ and "autoload_with" not in kw
+ ):
+ if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
+ kw["mysql_engine"] = "InnoDB"
else:
- kw['mysql_engine'] = 'MyISAM'
+ kw["mysql_engine"] = "MyISAM"
# Apply some default cascading rules for self-referential foreign keys.
# MySQL InnoDB has some issues around selecting self-refs too.
- if exclusions.against(config._current, 'firebird'):
+ if exclusions.against(config._current, "firebird"):
table_name = args[0]
- unpack = (config.db.dialect.
- identifier_preparer.unformat_identifiers)
+ unpack = config.db.dialect.identifier_preparer.unformat_identifiers
# Only going after ForeignKeys in Columns. May need to
# expand to ForeignKeyConstraint too.
- fks = [fk
- for col in args if isinstance(col, schema.Column)
- for fk in col.foreign_keys]
+ fks = [
+ fk
+ for col in args
+ if isinstance(col, schema.Column)
+ for fk in col.foreign_keys
+ ]
for fk in fks:
# root around in raw spec
@@ -54,9 +59,9 @@ def Table(*args, **kw):
name = unpack(ref)[0]
if name == table_name:
if fk.ondelete is None:
- fk.ondelete = 'CASCADE'
+ fk.ondelete = "CASCADE"
if fk.onupdate is None:
- fk.onupdate = 'CASCADE'
+ fk.onupdate = "CASCADE"
return schema.Table(*args, **kw)
@@ -64,37 +69,46 @@ def Table(*args, **kw):
def Column(*args, **kw):
"""A schema.Column wrapper/hook for dialect-specific tweaks."""
- test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')}
+ test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
if not config.requirements.foreign_key_ddl.enabled_for_config(config):
args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
col = schema.Column(*args, **kw)
- if test_opts.get('test_needs_autoincrement', False) and \
- kw.get('primary_key', False):
+ if test_opts.get("test_needs_autoincrement", False) and kw.get(
+ "primary_key", False
+ ):
if col.default is None and col.server_default is None:
col.autoincrement = True
# allow any test suite to pick up on this
- col.info['test_needs_autoincrement'] = True
+ col.info["test_needs_autoincrement"] = True
# hardcoded rule for firebird, oracle; this should
# be moved out
- if exclusions.against(config._current, 'firebird', 'oracle'):
+ if exclusions.against(config._current, "firebird", "oracle"):
+
def add_seq(c, tbl):
c._init_items(
- schema.Sequence(_truncate_name(
- config.db.dialect, tbl.name + '_' + c.name + '_seq'),
- optional=True)
+ schema.Sequence(
+ _truncate_name(
+ config.db.dialect, tbl.name + "_" + c.name + "_seq"
+ ),
+ optional=True,
+ )
)
- event.listen(col, 'after_parent_attach', add_seq, propagate=True)
+
+ event.listen(col, "after_parent_attach", add_seq, propagate=True)
return col
def _truncate_name(dialect, name):
if len(name) > dialect.max_identifier_length:
- return name[0:max(dialect.max_identifier_length - 6, 0)] + \
- "_" + hex(hash(name) % 64)[2:]
+ return (
+ name[0 : max(dialect.max_identifier_length - 6, 0)]
+ + "_"
+ + hex(hash(name) % 64)[2:]
+ )
else:
return name
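
`_truncate_name` above keeps generated identifiers (such as the `<table>_<column>_seq` sequence names built in `add_seq`) inside the dialect's `max_identifier_length` by truncating and appending a small hash suffix. A standalone version with the dialect replaced by a plain length parameter; note that `hash()` is salted per process on Python 3, so the suffix is not stable across runs:

```python
def truncate_name(name, max_identifier_length=30):
    if len(name) > max_identifier_length:
        return (
            name[0 : max(max_identifier_length - 6, 0)]
            + "_"
            + hex(hash(name) % 64)[2:]
        )
    return name


print(truncate_name("some_table_some_column_seq"))  # fits, unchanged
print(truncate_name("a_rather_long_table_some_column_seq"))  # truncated+hash
```
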
diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py
index 748d9722d..a4e142c5a 100644
--- a/lib/sqlalchemy/testing/suite/__init__.py
+++ b/lib/sqlalchemy/testing/suite/__init__.py
@@ -1,4 +1,3 @@
-
from sqlalchemy.testing.suite.test_cte import *
from sqlalchemy.testing.suite.test_dialect import *
from sqlalchemy.testing.suite.test_ddl import *
diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py
index cc72278e6..d2f35933b 100644
--- a/lib/sqlalchemy/testing/suite/test_cte.py
+++ b/lib/sqlalchemy/testing/suite/test_cte.py
@@ -10,22 +10,28 @@ from ..schema import Table, Column
class CTETest(fixtures.TablesTest):
__backend__ = True
- __requires__ = 'ctes',
+ __requires__ = ("ctes",)
- run_inserts = 'each'
- run_deletes = 'each'
+ run_inserts = "each"
+ run_deletes = "each"
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)),
- Column("parent_id", ForeignKey("some_table.id")))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", ForeignKey("some_table.id")),
+ )
- Table("some_other_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)),
- Column("parent_id", Integer))
+ Table(
+ "some_other_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", Integer),
+ )
@classmethod
def insert_data(cls):
@@ -36,28 +42,33 @@ class CTETest(fixtures.TablesTest):
{"id": 2, "data": "d2", "parent_id": 1},
{"id": 3, "data": "d3", "parent_id": 1},
{"id": 4, "data": "d4", "parent_id": 3},
- {"id": 5, "data": "d5", "parent_id": 3}
- ]
+ {"id": 5, "data": "d5", "parent_id": 3},
+ ],
)
def test_select_nonrecursive_round_trip(self):
some_table = self.tables.some_table
with config.db.connect() as conn:
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
result = conn.execute(
select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"]))
)
- eq_(result.fetchall(), [("d4", )])
+ eq_(result.fetchall(), [("d4",)])
def test_select_recursive_round_trip(self):
some_table = self.tables.some_table
with config.db.connect() as conn:
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])).cte(
- "some_cte", recursive=True)
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte", recursive=True)
+ )
cte_alias = cte.alias("c1")
st1 = some_table.alias()
@@ -67,12 +78,13 @@ class CTETest(fixtures.TablesTest):
select([st1]).where(st1.c.id == cte_alias.c.parent_id)
)
result = conn.execute(
- select([cte.c.data]).where(
- cte.c.data != "d2").order_by(cte.c.data.desc())
+ select([cte.c.data])
+ .where(cte.c.data != "d2")
+ .order_by(cte.c.data.desc())
)
eq_(
result.fetchall(),
- [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)]
+ [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
)
def test_insert_from_select_round_trip(self):
@@ -80,20 +92,21 @@ class CTETest(fixtures.TablesTest):
some_other_table = self.tables.some_other_table
with config.db.connect() as conn:
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
some_other_table.insert().from_select(
- ["id", "data", "parent_id"],
- select([cte])
+ ["id", "data", "parent_id"], select([cte])
)
)
eq_(
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
- [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)]
+ [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
)
@testing.requires.ctes_with_update_delete
@@ -105,27 +118,31 @@ class CTETest(fixtures.TablesTest):
with config.db.connect() as conn:
conn.execute(
some_other_table.insert().from_select(
- ['id', 'data', 'parent_id'],
- select([some_table])
+ ["id", "data", "parent_id"], select([some_table])
)
)
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
- some_other_table.update().values(parent_id=5).where(
- some_other_table.c.data == cte.c.data
- )
+ some_other_table.update()
+ .values(parent_id=5)
+ .where(some_other_table.c.data == cte.c.data)
)
eq_(
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
[
- (1, "d1", None), (2, "d2", 5),
- (3, "d3", 5), (4, "d4", 5), (5, "d5", 3)
- ]
+ (1, "d1", None),
+ (2, "d2", 5),
+ (3, "d3", 5),
+ (4, "d4", 5),
+ (5, "d5", 3),
+ ],
)
@testing.requires.ctes_with_update_delete
@@ -137,14 +154,15 @@ class CTETest(fixtures.TablesTest):
with config.db.connect() as conn:
conn.execute(
some_other_table.insert().from_select(
- ['id', 'data', 'parent_id'],
- select([some_table])
+ ["id", "data", "parent_id"], select([some_table])
)
)
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
some_other_table.delete().where(
some_other_table.c.data == cte.c.data
@@ -154,9 +172,7 @@ class CTETest(fixtures.TablesTest):
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
- [
- (1, "d1", None), (5, "d5", 3)
- ]
+ [(1, "d1", None), (5, "d5", 3)],
)
@testing.requires.ctes_with_update_delete
@@ -168,26 +184,26 @@ class CTETest(fixtures.TablesTest):
with config.db.connect() as conn:
conn.execute(
some_other_table.insert().from_select(
- ['id', 'data', 'parent_id'],
- select([some_table])
+ ["id", "data", "parent_id"], select([some_table])
)
)
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
some_other_table.delete().where(
- some_other_table.c.data ==
- select([cte.c.data]).where(
- cte.c.id == some_other_table.c.id)
+ some_other_table.c.data
+ == select([cte.c.data]).where(
+ cte.c.id == some_other_table.c.id
+ )
)
)
eq_(
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
- [
- (1, "d1", None), (5, "d5", 3)
- ]
+ [(1, "d1", None), (5, "d5", 3)],
)
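
`test_select_recursive_round_trip` above follows the standard recursive-CTE recipe: a seed query, `.cte(..., recursive=True)`, then `union_all` against an alias of the CTE. A runnable reduction under the 1.x-style `select([...])` API this suite uses, against an in-memory SQLite database (table shape borrowed from the fixture, data invented):

```python
from sqlalchemy import (
    Column,
    ForeignKey,
    Integer,
    MetaData,
    String,
    Table,
    create_engine,
    select,
)

engine = create_engine("sqlite://")
metadata = MetaData()
some_table = Table(
    "some_table",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("data", String(50)),
    Column("parent_id", ForeignKey("some_table.id")),
)
metadata.create_all(engine)

with engine.connect() as conn:
    conn.execute(
        some_table.insert(),
        [
            {"id": 1, "data": "d1", "parent_id": None},
            {"id": 2, "data": "d2", "parent_id": 1},
            {"id": 3, "data": "d3", "parent_id": 2},
        ],
    )
    # seed: root rows; recursive step: children of rows already in the CTE
    cte = (
        select([some_table])
        .where(some_table.c.parent_id == None)  # noqa: E711
        .cte("tree", recursive=True)
    )
    st = some_table.alias()
    cte = cte.union_all(select([st]).where(st.c.parent_id == cte.c.id))
    rows = conn.execute(select([cte.c.data]).order_by(cte.c.data)).fetchall()
    print(rows)  # [('d1',), ('d2',), ('d3',)]
```
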
diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py
index 1d8010c8a..7c44388d4 100644
--- a/lib/sqlalchemy/testing/suite/test_ddl.py
+++ b/lib/sqlalchemy/testing/suite/test_ddl.py
@@ -1,5 +1,3 @@
-
-
from .. import fixtures, config, util
from ..config import requirements
from ..assertions import eq_
@@ -11,55 +9,47 @@ class TableDDLTest(fixtures.TestBase):
__backend__ = True
def _simple_fixture(self):
- return Table('test_table', self.metadata,
- Column('id', Integer, primary_key=True,
- autoincrement=False),
- Column('data', String(50))
- )
+ return Table(
+ "test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
def _underscore_fixture(self):
- return Table('_test_table', self.metadata,
- Column('id', Integer, primary_key=True,
- autoincrement=False),
- Column('_data', String(50))
- )
+ return Table(
+ "_test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("_data", String(50)),
+ )
def _simple_roundtrip(self, table):
with config.db.begin() as conn:
- conn.execute(table.insert().values((1, 'some data')))
+ conn.execute(table.insert().values((1, "some data")))
result = conn.execute(table.select())
- eq_(
- result.first(),
- (1, 'some data')
- )
+ eq_(result.first(), (1, "some data"))
@requirements.create_table
@util.provide_metadata
def test_create_table(self):
table = self._simple_fixture()
- table.create(
- config.db, checkfirst=False
- )
+ table.create(config.db, checkfirst=False)
self._simple_roundtrip(table)
@requirements.drop_table
@util.provide_metadata
def test_drop_table(self):
table = self._simple_fixture()
- table.create(
- config.db, checkfirst=False
- )
- table.drop(
- config.db, checkfirst=False
- )
+ table.create(config.db, checkfirst=False)
+ table.drop(config.db, checkfirst=False)
@requirements.create_table
@util.provide_metadata
def test_underscore_names(self):
table = self._underscore_fixture()
- table.create(
- config.db, checkfirst=False
- )
+ table.create(config.db, checkfirst=False)
self._simple_roundtrip(table)
-__all__ = ('TableDDLTest', )
+
+__all__ = ("TableDDLTest",)
diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py
index 2c5dd0e36..5e589f3b8 100644
--- a/lib/sqlalchemy/testing/suite/test_dialect.py
+++ b/lib/sqlalchemy/testing/suite/test_dialect.py
@@ -15,16 +15,19 @@ class ExceptionTest(fixtures.TablesTest):
specific exceptions from real round trips, we need to be conservative.
"""
- run_deletes = 'each'
+
+ run_deletes = "each"
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('manual_pk', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50))
- )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
@requirements.duplicate_key_raises_integrity_error
def test_integrity_error(self):
@@ -33,15 +36,14 @@ class ExceptionTest(fixtures.TablesTest):
trans = conn.begin()
conn.execute(
- self.tables.manual_pk.insert(),
- {'id': 1, 'data': 'd1'}
+ self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
)
assert_raises(
exc.IntegrityError,
conn.execute,
self.tables.manual_pk.insert(),
- {'id': 1, 'data': 'd1'}
+ {"id": 1, "data": "d1"},
)
trans.rollback()
@@ -49,38 +51,39 @@ class ExceptionTest(fixtures.TablesTest):
class AutocommitTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
- __requires__ = 'autocommit',
+ __requires__ = ("autocommit",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('some_table', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50)),
- test_needs_acid=True
- )
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ test_needs_acid=True,
+ )
def _test_conn_autocommits(self, conn, autocommit):
trans = conn.begin()
conn.execute(
- self.tables.some_table.insert(),
- {"id": 1, "data": "some data"}
+ self.tables.some_table.insert(), {"id": 1, "data": "some data"}
)
trans.rollback()
eq_(
conn.scalar(select([self.tables.some_table.c.id])),
- 1 if autocommit else None
+ 1 if autocommit else None,
)
conn.execute(self.tables.some_table.delete())
def test_autocommit_on(self):
conn = config.db.connect()
- c2 = conn.execution_options(isolation_level='AUTOCOMMIT')
+ c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
self._test_conn_autocommits(c2, True)
conn.invalidate()
self._test_conn_autocommits(conn, False)
@@ -98,7 +101,7 @@ class EscapingTest(fixtures.TestBase):
"""
m = self.metadata
- t = Table('t', m, Column('data', String(50)))
+ t = Table("t", m, Column("data", String(50)))
t.create(config.db)
with config.db.begin() as conn:
conn.execute(t.insert(), dict(data="some % value"))
@@ -107,14 +110,17 @@ class EscapingTest(fixtures.TestBase):
eq_(
conn.scalar(
select([t.c.data]).where(
- t.c.data == literal_column("'some % value'"))
+ t.c.data == literal_column("'some % value'")
+ )
),
- "some % value"
+ "some % value",
)
eq_(
conn.scalar(
select([t.c.data]).where(
- t.c.data == literal_column("'some %% other value'"))
- ), "some %% other value"
+ t.c.data == literal_column("'some %% other value'")
+ )
+ ),
+ "some %% other value",
)
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
index c0b6b18eb..6257451eb 100644
--- a/lib/sqlalchemy/testing/suite/test_insert.py
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -10,53 +10,48 @@ from ..schema import Table, Column
class LastrowidTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
__backend__ = True
- __requires__ = 'implements_get_lastrowid', 'autoincrement_insert'
+ __requires__ = "implements_get_lastrowid", "autoincrement_insert"
__engine_options__ = {"implicit_returning": False}
@classmethod
def define_tables(cls, metadata):
- Table('autoinc_pk', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50))
- )
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
- Table('manual_pk', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50))
- )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
- eq_(
- row,
- (config.db.dialect.default_sequence_base, "some data")
- )
+ eq_(row, (config.db.dialect.default_sequence_base, "some data"))
def test_autoincrement_on_insert(self):
- config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
- )
+ config.db.execute(self.tables.autoinc_pk.insert(), data="some data")
self._assert_round_trip(self.tables.autoinc_pk, config.db)
def test_last_inserted_id(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
- eq_(
- r.inserted_primary_key,
- [pk]
- )
+ eq_(r.inserted_primary_key, [pk])
# failed on pypy1.9 but seems to be OK on pypy 2.1
# @exclusions.fails_if(lambda: util.pypy,
@@ -65,50 +60,57 @@ class LastrowidTest(fixtures.TablesTest):
@requirements.dbapi_lastrowid
def test_native_lastrowid_autoinc(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
lastrowid = r.lastrowid
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
- eq_(
- lastrowid, pk
- )
+ eq_(lastrowid, pk)
class InsertBehaviorTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('autoinc_pk', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50))
- )
- Table('manual_pk', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50))
- )
- Table('includes_defaults', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50)),
- Column('x', Integer, default=5),
- Column('y', Integer,
- default=literal_column("2", type_=Integer) + literal(2)))
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+ Table(
+ "includes_defaults",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ Column("x", Integer, default=5),
+ Column(
+ "y",
+ Integer,
+ default=literal_column("2", type_=Integer) + literal(2),
+ ),
+ )
def test_autoclose_on_insert(self):
if requirements.returning.enabled:
engine = engines.testing_engine(
- options={'implicit_returning': False})
+ options={"implicit_returning": False}
+ )
else:
engine = config.db
- r = engine.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
- )
+ r = engine.execute(self.tables.autoinc_pk.insert(), data="some data")
assert r._soft_closed
assert not r.closed
assert r.is_insert
@@ -117,8 +119,7 @@ class InsertBehaviorTest(fixtures.TablesTest):
@requirements.returning
def test_autoclose_on_insert_implicit_returning(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
assert r._soft_closed
assert not r.closed
@@ -127,15 +128,14 @@ class InsertBehaviorTest(fixtures.TablesTest):
@requirements.empty_inserts
def test_empty_insert(self):
- r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- )
+ r = config.db.execute(self.tables.autoinc_pk.insert())
assert r._soft_closed
assert not r.closed
r = config.db.execute(
- self.tables.autoinc_pk.select().
- where(self.tables.autoinc_pk.c.id != None)
+ self.tables.autoinc_pk.select().where(
+ self.tables.autoinc_pk.c.id != None
+ )
)
assert len(r.fetchall())
@@ -150,15 +150,15 @@ class InsertBehaviorTest(fixtures.TablesTest):
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
- ]
+ ],
)
result = config.db.execute(
- dest_table.insert().
- from_select(
+ dest_table.insert().from_select(
("data",),
- select([src_table.c.data]).
- where(src_table.c.data.in_(["data2", "data3"]))
+ select([src_table.c.data]).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
)
)
@@ -167,7 +167,7 @@ class InsertBehaviorTest(fixtures.TablesTest):
result = config.db.execute(
select([dest_table.c.data]).order_by(dest_table.c.data)
)
- eq_(result.fetchall(), [("data2", ), ("data3", )])
+ eq_(result.fetchall(), [("data2",), ("data3",)])
@requirements.insert_from_select
def test_insert_from_select_autoinc_no_rows(self):
@@ -175,11 +175,11 @@ class InsertBehaviorTest(fixtures.TablesTest):
dest_table = self.tables.autoinc_pk
result = config.db.execute(
- dest_table.insert().
- from_select(
+ dest_table.insert().from_select(
("data",),
- select([src_table.c.data]).
- where(src_table.c.data.in_(["data2", "data3"]))
+ select([src_table.c.data]).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
)
)
eq_(result.inserted_primary_key, [None])
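
For readers skimming the diff, a minimal runnable sketch of the insert().from_select() construct these tests cover, assuming an in-memory SQLite engine; table and column names are illustrative:

    from sqlalchemy import (
        create_engine, MetaData, Table, Column, Integer, String, select,
    )

    engine = create_engine("sqlite://")
    meta = MetaData()
    src = Table("src", meta, Column("data", String(50)))
    dest = Table(
        "dest",
        meta,
        Column("id", Integer, primary_key=True),
        Column("data", String(50)),
    )
    meta.create_all(engine)
    engine.execute(src.insert(), [{"data": "d2"}, {"data": "d3"}])

    # INSERT ... SELECT: target column names first, then the SELECT
    engine.execute(
        dest.insert().from_select(("data",), select([src.c.data]))
    )

Because the statement may insert any number of rows, no single autogenerated key can be reported, which is why the test above expects inserted_primary_key to come back as [None].
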
@@ -199,23 +199,23 @@ class InsertBehaviorTest(fixtures.TablesTest):
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
- ]
+ ],
)
config.db.execute(
- table.insert(inline=True).
- from_select(("id", "data",),
- select([table.c.id + 5, table.c.data]).
- where(table.c.data.in_(["data2", "data3"]))
- ),
+ table.insert(inline=True).from_select(
+ ("id", "data"),
+ select([table.c.id + 5, table.c.data]).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
)
eq_(
config.db.execute(
select([table.c.data]).order_by(table.c.data)
).fetchall(),
- [("data1", ), ("data2", ), ("data2", ),
- ("data3", ), ("data3", )]
+ [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
)
@requirements.insert_from_select
@@ -227,56 +227,60 @@ class InsertBehaviorTest(fixtures.TablesTest):
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
- ]
+ ],
)
config.db.execute(
- table.insert(inline=True).
- from_select(("id", "data",),
- select([table.c.id + 5, table.c.data]).
- where(table.c.data.in_(["data2", "data3"]))
- ),
+ table.insert(inline=True).from_select(
+ ("id", "data"),
+ select([table.c.id + 5, table.c.data]).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
)
eq_(
config.db.execute(
select([table]).order_by(table.c.data, table.c.id)
).fetchall(),
- [(1, 'data1', 5, 4), (2, 'data2', 5, 4),
- (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)]
+ [
+ (1, "data1", 5, 4),
+ (2, "data2", 5, 4),
+ (7, "data2", 5, 4),
+ (3, "data3", 5, 4),
+ (8, "data3", 5, 4),
+ ],
)
class ReturningTest(fixtures.TablesTest):
- run_create_tables = 'each'
- __requires__ = 'returning', 'autoincrement_insert'
+ run_create_tables = "each"
+ __requires__ = "returning", "autoincrement_insert"
__backend__ = True
__engine_options__ = {"implicit_returning": True}
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
- eq_(
- row,
- (config.db.dialect.default_sequence_base, "some data")
- )
+ eq_(row, (config.db.dialect.default_sequence_base, "some data"))
@classmethod
def define_tables(cls, metadata):
- Table('autoinc_pk', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50))
- )
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
@requirements.fetch_rows_post_commit
def test_explicit_returning_pk_autocommit(self):
engine = config.db
table = self.tables.autoinc_pk
r = engine.execute(
- table.insert().returning(
- table.c.id),
- data="some data"
+ table.insert().returning(table.c.id), data="some data"
)
pk = r.first()[0]
fetched_pk = config.db.scalar(select([table.c.id]))
@@ -287,9 +291,7 @@ class ReturningTest(fixtures.TablesTest):
table = self.tables.autoinc_pk
with engine.begin() as conn:
r = conn.execute(
- table.insert().returning(
- table.c.id),
- data="some data"
+ table.insert().returning(table.c.id), data="some data"
)
pk = r.first()[0]
fetched_pk = config.db.scalar(select([table.c.id]))
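
A sketch of the RETURNING round trip under test, assuming a PostgreSQL URL (illustrative; RETURNING is only available on backends that support it):

    from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    meta = MetaData()
    t = Table(
        "autoinc_pk",
        meta,
        Column("id", Integer, primary_key=True),
        Column("data", String(50)),
    )
    meta.create_all(engine)

    # the generated key comes back in the result rows themselves,
    # with no second SELECT against the table
    r = engine.execute(t.insert().returning(t.c.id), data="some data")
    pk = r.first()[0]
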
@@ -297,23 +299,16 @@ class ReturningTest(fixtures.TablesTest):
def test_autoincrement_on_insert_implcit_returning(self):
- config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
- )
+ config.db.execute(self.tables.autoinc_pk.insert(), data="some data")
self._assert_round_trip(self.tables.autoinc_pk, config.db)
def test_last_inserted_id_implicit_returning(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
- eq_(
- r.inserted_primary_key,
- [pk]
- )
+ eq_(r.inserted_primary_key, [pk])
-__all__ = ('LastrowidTest', 'InsertBehaviorTest', 'ReturningTest')
+__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 00a5aac01..bfed5f1ab 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -1,5 +1,3 @@
-
-
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
from sqlalchemy import types as sql_types
@@ -26,10 +24,12 @@ class HasTableTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table('test_table', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
def test_has_table(self):
with config.db.begin() as conn:
@@ -46,8 +46,10 @@ class ComponentReflectionTest(fixtures.TablesTest):
def setup_bind(cls):
if config.requirements.independent_connections.enabled:
from sqlalchemy import pool
+
return engines.testing_engine(
- options=dict(poolclass=pool.StaticPool))
+ options=dict(poolclass=pool.StaticPool)
+ )
else:
return config.db
@@ -65,86 +67,109 @@ class ComponentReflectionTest(fixtures.TablesTest):
schema_prefix = ""
if testing.requires.self_referential_foreign_keys.enabled:
- users = Table('users', metadata,
- Column('user_id', sa.INT, primary_key=True),
- Column('test1', sa.CHAR(5), nullable=False),
- Column('test2', sa.Float(5), nullable=False),
- Column('parent_user_id', sa.Integer,
- sa.ForeignKey('%susers.user_id' %
- schema_prefix,
- name='user_id_fk')),
- schema=schema,
- test_needs_fk=True,
- )
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ Column(
+ "parent_user_id",
+ sa.Integer,
+ sa.ForeignKey(
+ "%susers.user_id" % schema_prefix, name="user_id_fk"
+ ),
+ ),
+ schema=schema,
+ test_needs_fk=True,
+ )
else:
- users = Table('users', metadata,
- Column('user_id', sa.INT, primary_key=True),
- Column('test1', sa.CHAR(5), nullable=False),
- Column('test2', sa.Float(5), nullable=False),
- schema=schema,
- test_needs_fk=True,
- )
-
- Table("dingalings", metadata,
- Column('dingaling_id', sa.Integer, primary_key=True),
- Column('address_id', sa.Integer,
- sa.ForeignKey('%semail_addresses.address_id' %
- schema_prefix)),
- Column('data', sa.String(30)),
- schema=schema,
- test_needs_fk=True,
- )
- Table('email_addresses', metadata,
- Column('address_id', sa.Integer),
- Column('remote_user_id', sa.Integer,
- sa.ForeignKey(users.c.user_id)),
- Column('email_address', sa.String(20)),
- sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'),
- schema=schema,
- test_needs_fk=True,
- )
- Table('comment_test', metadata,
- Column('id', sa.Integer, primary_key=True, comment='id comment'),
- Column('data', sa.String(20), comment='data % comment'),
- Column(
- 'd2', sa.String(20),
- comment=r"""Comment types type speedily ' " \ '' Fun!"""),
- schema=schema,
- comment=r"""the test % ' " \ table comment""")
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ schema=schema,
+ test_needs_fk=True,
+ )
+
+ Table(
+ "dingalings",
+ metadata,
+ Column("dingaling_id", sa.Integer, primary_key=True),
+ Column(
+ "address_id",
+ sa.Integer,
+ sa.ForeignKey("%semail_addresses.address_id" % schema_prefix),
+ ),
+ Column("data", sa.String(30)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "email_addresses",
+ metadata,
+ Column("address_id", sa.Integer),
+ Column(
+ "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id)
+ ),
+ Column("email_address", sa.String(20)),
+ sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "comment_test",
+ metadata,
+ Column("id", sa.Integer, primary_key=True, comment="id comment"),
+ Column("data", sa.String(20), comment="data % comment"),
+ Column(
+ "d2",
+ sa.String(20),
+ comment=r"""Comment types type speedily ' " \ '' Fun!""",
+ ),
+ schema=schema,
+ comment=r"""the test % ' " \ table comment""",
+ )
if testing.requires.cross_schema_fk_reflection.enabled:
if schema is None:
Table(
- 'local_table', metadata,
- Column('id', sa.Integer, primary_key=True),
- Column('data', sa.String(20)),
+ "local_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
Column(
- 'remote_id',
+ "remote_id",
ForeignKey(
- '%s.remote_table_2.id' %
- testing.config.test_schema)
+ "%s.remote_table_2.id" % testing.config.test_schema
+ ),
),
test_needs_fk=True,
- schema=config.db.dialect.default_schema_name
+ schema=config.db.dialect.default_schema_name,
)
else:
Table(
- 'remote_table', metadata,
- Column('id', sa.Integer, primary_key=True),
+ "remote_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
Column(
- 'local_id',
+ "local_id",
ForeignKey(
- '%s.local_table.id' %
- config.db.dialect.default_schema_name)
+ "%s.local_table.id"
+ % config.db.dialect.default_schema_name
+ ),
),
- Column('data', sa.String(20)),
+ Column("data", sa.String(20)),
schema=schema,
test_needs_fk=True,
)
Table(
- 'remote_table_2', metadata,
- Column('id', sa.Integer, primary_key=True),
- Column('data', sa.String(20)),
+ "remote_table_2",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
schema=schema,
test_needs_fk=True,
)
@@ -155,19 +180,21 @@ class ComponentReflectionTest(fixtures.TablesTest):
if not schema:
            # test_needs_fk is used at the moment to force MySQL InnoDB
noncol_idx_test_nopk = Table(
- 'noncol_idx_test_nopk', metadata,
- Column('q', sa.String(5)),
+ "noncol_idx_test_nopk",
+ metadata,
+ Column("q", sa.String(5)),
test_needs_fk=True,
)
noncol_idx_test_pk = Table(
- 'noncol_idx_test_pk', metadata,
- Column('id', sa.Integer, primary_key=True),
- Column('q', sa.String(5)),
+ "noncol_idx_test_pk",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("q", sa.String(5)),
test_needs_fk=True,
)
- Index('noncol_idx_nopk', noncol_idx_test_nopk.c.q.desc())
- Index('noncol_idx_pk', noncol_idx_test_pk.c.q.desc())
+ Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc())
+ Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc())
if testing.requires.view_column_reflection.enabled:
cls.define_views(metadata, schema)
@@ -180,34 +207,35 @@ class ComponentReflectionTest(fixtures.TablesTest):
# temp table fixture
if testing.against("oracle"):
kw = {
- 'prefixes': ["GLOBAL TEMPORARY"],
- 'oracle_on_commit': 'PRESERVE ROWS'
+ "prefixes": ["GLOBAL TEMPORARY"],
+ "oracle_on_commit": "PRESERVE ROWS",
}
else:
- kw = {
- 'prefixes': ["TEMPORARY"],
- }
+ kw = {"prefixes": ["TEMPORARY"]}
user_tmp = Table(
- "user_tmp", metadata,
+ "user_tmp",
+ metadata,
Column("id", sa.INT, primary_key=True),
- Column('name', sa.VARCHAR(50)),
- Column('foo', sa.INT),
- sa.UniqueConstraint('name', name='user_tmp_uq'),
+ Column("name", sa.VARCHAR(50)),
+ Column("foo", sa.INT),
+ sa.UniqueConstraint("name", name="user_tmp_uq"),
sa.Index("user_tmp_ix", "foo"),
**kw
)
- if testing.requires.view_reflection.enabled and \
- testing.requires.temporary_views.enabled:
- event.listen(
- user_tmp, "after_create",
- DDL("create temporary view user_tmp_v as "
- "select * from user_tmp")
- )
+ if (
+ testing.requires.view_reflection.enabled
+ and testing.requires.temporary_views.enabled
+ ):
event.listen(
- user_tmp, "before_drop",
- DDL("drop view user_tmp_v")
+ user_tmp,
+ "after_create",
+ DDL(
+ "create temporary view user_tmp_v as "
+ "select * from user_tmp"
+ ),
)
+ event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v"))
@classmethod
def define_index(cls, metadata, users):
@@ -216,23 +244,19 @@ class ComponentReflectionTest(fixtures.TablesTest):
@classmethod
def define_views(cls, metadata, schema):
- for table_name in ('users', 'email_addresses'):
+ for table_name in ("users", "email_addresses"):
fullname = table_name
if schema:
fullname = "%s.%s" % (schema, table_name)
- view_name = fullname + '_v'
+ view_name = fullname + "_v"
query = "CREATE VIEW %s AS SELECT * FROM %s" % (
- view_name, fullname)
-
- event.listen(
- metadata,
- "after_create",
- DDL(query)
+ view_name,
+ fullname,
)
+
+ event.listen(metadata, "after_create", DDL(query))
event.listen(
- metadata,
- "before_drop",
- DDL("DROP VIEW %s" % view_name)
+ metadata, "before_drop", DDL("DROP VIEW %s" % view_name)
)
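
The fixture above hangs raw DDL off schema events so the views track their tables through create_all()/drop_all(). A minimal sketch of that pattern, assuming SQLite and an illustrative view name:

    from sqlalchemy import (
        DDL, Column, Integer, MetaData, Table, create_engine, event,
    )

    engine = create_engine("sqlite://")
    meta = MetaData()
    Table("users", meta, Column("id", Integer, primary_key=True))

    # emit the view DDL right after the tables exist; drop the view
    # just before the tables go away
    event.listen(
        meta, "after_create", DDL("CREATE VIEW users_v AS SELECT * FROM users")
    )
    event.listen(meta, "before_drop", DDL("DROP VIEW users_v"))

    meta.create_all(engine)  # users, then users_v
    meta.drop_all(engine)    # users_v, then users
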
@testing.requires.schema_reflection
@@ -244,9 +268,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.schema_reflection
def test_dialect_initialize(self):
engine = engines.testing_engine()
- assert not hasattr(engine.dialect, 'default_schema_name')
+ assert not hasattr(engine.dialect, "default_schema_name")
inspect(engine)
- assert hasattr(engine.dialect, 'default_schema_name')
+ assert hasattr(engine.dialect, "default_schema_name")
@testing.requires.schema_reflection
def test_get_default_schema_name(self):
@@ -254,40 +278,49 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(insp.default_schema_name, testing.db.dialect.default_schema_name)
@testing.provide_metadata
- def _test_get_table_names(self, schema=None, table_type='table',
- order_by=None):
+ def _test_get_table_names(
+ self, schema=None, table_type="table", order_by=None
+ ):
_ignore_tables = [
- 'comment_test', 'noncol_idx_test_pk', 'noncol_idx_test_nopk',
- 'local_table', 'remote_table', 'remote_table_2'
+ "comment_test",
+ "noncol_idx_test_pk",
+ "noncol_idx_test_nopk",
+ "local_table",
+ "remote_table",
+ "remote_table_2",
]
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
insp = inspect(meta.bind)
- if table_type == 'view':
+ if table_type == "view":
table_names = insp.get_view_names(schema)
table_names.sort()
- answer = ['email_addresses_v', 'users_v']
+ answer = ["email_addresses_v", "users_v"]
eq_(sorted(table_names), answer)
else:
table_names = [
- t for t in insp.get_table_names(
- schema,
- order_by=order_by) if t not in _ignore_tables]
+ t
+ for t in insp.get_table_names(schema, order_by=order_by)
+ if t not in _ignore_tables
+ ]
- if order_by == 'foreign_key':
- answer = ['users', 'email_addresses', 'dingalings']
+ if order_by == "foreign_key":
+ answer = ["users", "email_addresses", "dingalings"]
eq_(table_names, answer)
else:
- answer = ['dingalings', 'email_addresses', 'users']
+ answer = ["dingalings", "email_addresses", "users"]
eq_(sorted(table_names), answer)
@testing.requires.temp_table_names
def test_get_temp_table_names(self):
insp = inspect(self.bind)
temp_table_names = insp.get_temp_table_names()
- eq_(sorted(temp_table_names), ['user_tmp'])
+ eq_(sorted(temp_table_names), ["user_tmp"])
@testing.requires.view_reflection
@testing.requires.temp_table_names
@@ -295,7 +328,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
def test_get_temp_view_names(self):
insp = inspect(self.bind)
temp_table_names = insp.get_temp_view_names()
- eq_(sorted(temp_table_names), ['user_tmp_v'])
+ eq_(sorted(temp_table_names), ["user_tmp_v"])
@testing.requires.table_reflection
def test_get_table_names(self):
@@ -304,7 +337,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.table_reflection
@testing.requires.foreign_key_constraint_reflection
def test_get_table_names_fks(self):
- self._test_get_table_names(order_by='foreign_key')
+ self._test_get_table_names(order_by="foreign_key")
@testing.requires.comment_reflection
def test_get_comments(self):
@@ -320,26 +353,24 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(
insp.get_table_comment("comment_test", schema=schema),
- {"text": r"""the test % ' " \ table comment"""}
+ {"text": r"""the test % ' " \ table comment"""},
)
- eq_(
- insp.get_table_comment("users", schema=schema),
- {"text": None}
- )
+ eq_(insp.get_table_comment("users", schema=schema), {"text": None})
eq_(
[
- {"name": rec['name'], "comment": rec['comment']}
- for rec in
- insp.get_columns("comment_test", schema=schema)
+ {"name": rec["name"], "comment": rec["comment"]}
+ for rec in insp.get_columns("comment_test", schema=schema)
],
[
- {'comment': 'id comment', 'name': 'id'},
- {'comment': 'data % comment', 'name': 'data'},
- {'comment': r"""Comment types type speedily ' " \ '' Fun!""",
- 'name': 'd2'}
- ]
+ {"comment": "id comment", "name": "id"},
+ {"comment": "data % comment", "name": "data"},
+ {
+ "comment": r"""Comment types type speedily ' " \ '' Fun!""",
+ "name": "d2",
+ },
+ ],
)
@testing.requires.table_reflection
@@ -349,30 +380,33 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.view_column_reflection
def test_get_view_names(self):
- self._test_get_table_names(table_type='view')
+ self._test_get_table_names(table_type="view")
@testing.requires.view_column_reflection
@testing.requires.schemas
def test_get_view_names_with_schema(self):
self._test_get_table_names(
- testing.config.test_schema, table_type='view')
+ testing.config.test_schema, table_type="view"
+ )
@testing.requires.table_reflection
@testing.requires.view_column_reflection
def test_get_tables_and_views(self):
self._test_get_table_names()
- self._test_get_table_names(table_type='view')
+ self._test_get_table_names(table_type="view")
- def _test_get_columns(self, schema=None, table_type='table'):
+ def _test_get_columns(self, schema=None, table_type="table"):
meta = MetaData(testing.db)
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
- table_names = ['users', 'email_addresses']
- if table_type == 'view':
- table_names = ['users_v', 'email_addresses_v']
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
+ table_names = ["users", "email_addresses"]
+ if table_type == "view":
+ table_names = ["users_v", "email_addresses_v"]
insp = inspect(meta.bind)
- for table_name, table in zip(table_names, (users,
- addresses)):
+ for table_name, table in zip(table_names, (users, addresses)):
schema_name = schema
cols = insp.get_columns(table_name, schema=schema_name)
self.assert_(len(cols) > 0, len(cols))
@@ -380,36 +414,46 @@ class ComponentReflectionTest(fixtures.TablesTest):
# should be in order
for i, col in enumerate(table.columns):
- eq_(col.name, cols[i]['name'])
- ctype = cols[i]['type'].__class__
+ eq_(col.name, cols[i]["name"])
+ ctype = cols[i]["type"].__class__
ctype_def = col.type
if isinstance(ctype_def, sa.types.TypeEngine):
ctype_def = ctype_def.__class__
# Oracle returns Date for DateTime.
- if testing.against('oracle') and ctype_def \
- in (sql_types.Date, sql_types.DateTime):
+ if testing.against("oracle") and ctype_def in (
+ sql_types.Date,
+ sql_types.DateTime,
+ ):
ctype_def = sql_types.Date
# assert that the desired type and return type share
# a base within one of the generic types.
- self.assert_(len(set(ctype.__mro__).
- intersection(ctype_def.__mro__).
- intersection([
- sql_types.Integer,
- sql_types.Numeric,
- sql_types.DateTime,
- sql_types.Date,
- sql_types.Time,
- sql_types.String,
- sql_types._Binary,
- ])) > 0, '%s(%s), %s(%s)' %
- (col.name, col.type, cols[i]['name'], ctype))
+ self.assert_(
+ len(
+ set(ctype.__mro__)
+ .intersection(ctype_def.__mro__)
+ .intersection(
+ [
+ sql_types.Integer,
+ sql_types.Numeric,
+ sql_types.DateTime,
+ sql_types.Date,
+ sql_types.Time,
+ sql_types.String,
+ sql_types._Binary,
+ ]
+ )
+ )
+ > 0,
+ "%s(%s), %s(%s)"
+ % (col.name, col.type, cols[i]["name"], ctype),
+ )
if not col.primary_key:
- assert cols[i]['default'] is None
+ assert cols[i]["default"] is None
@testing.requires.table_reflection
def test_get_columns(self):
@@ -417,24 +461,20 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.provide_metadata
def _type_round_trip(self, *types):
- t = Table('t', self.metadata,
- *[
- Column('t%d' % i, type_)
- for i, type_ in enumerate(types)
- ]
- )
+ t = Table(
+ "t",
+ self.metadata,
+ *[Column("t%d" % i, type_) for i, type_ in enumerate(types)]
+ )
t.create()
return [
- c['type'] for c in
- inspect(self.metadata.bind).get_columns('t')
+ c["type"] for c in inspect(self.metadata.bind).get_columns("t")
]
@testing.requires.table_reflection
def test_numeric_reflection(self):
- for typ in self._type_round_trip(
- sql_types.Numeric(18, 5),
- ):
+ for typ in self._type_round_trip(sql_types.Numeric(18, 5)):
assert isinstance(typ, sql_types.Numeric)
eq_(typ.precision, 18)
eq_(typ.scale, 5)
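
A sketch of the reflection round trip being asserted, assuming SQLite. Reflected types are only guaranteed to share a generic base (Numeric, String, ...) with the declared ones, not to be the identical class, hence the __mro__ intersection check earlier in the diff:

    from sqlalchemy import create_engine, inspect, MetaData, Table, Column, Numeric

    engine = create_engine("sqlite://")
    meta = MetaData()
    Table("t", meta, Column("t0", Numeric(18, 5)))
    meta.create_all(engine)

    col = inspect(engine).get_columns("t")[0]
    # the reflected type keeps precision/scale where the backend stores them
    print(col["type"].precision, col["type"].scale)  # 18 5
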
@@ -448,16 +488,19 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.table_reflection
@testing.provide_metadata
def test_nullable_reflection(self):
- t = Table('t', self.metadata,
- Column('a', Integer, nullable=True),
- Column('b', Integer, nullable=False))
+ t = Table(
+ "t",
+ self.metadata,
+ Column("a", Integer, nullable=True),
+ Column("b", Integer, nullable=False),
+ )
t.create()
eq_(
dict(
- (col['name'], col['nullable'])
- for col in inspect(self.metadata.bind).get_columns('t')
+ (col["name"], col["nullable"])
+ for col in inspect(self.metadata.bind).get_columns("t")
),
- {"a": True, "b": False}
+ {"a": True, "b": False},
)
@testing.requires.table_reflection
@@ -470,32 +513,30 @@ class ComponentReflectionTest(fixtures.TablesTest):
meta = MetaData(self.bind)
user_tmp = self.tables.user_tmp
insp = inspect(meta.bind)
- cols = insp.get_columns('user_tmp')
+ cols = insp.get_columns("user_tmp")
self.assert_(len(cols) > 0, len(cols))
for i, col in enumerate(user_tmp.columns):
- eq_(col.name, cols[i]['name'])
+ eq_(col.name, cols[i]["name"])
@testing.requires.temp_table_reflection
@testing.requires.view_column_reflection
@testing.requires.temporary_views
def test_get_temp_view_columns(self):
insp = inspect(self.bind)
- cols = insp.get_columns('user_tmp_v')
- eq_(
- [col['name'] for col in cols],
- ['id', 'name', 'foo']
- )
+ cols = insp.get_columns("user_tmp_v")
+ eq_([col["name"] for col in cols], ["id", "name", "foo"])
@testing.requires.view_column_reflection
def test_get_view_columns(self):
- self._test_get_columns(table_type='view')
+ self._test_get_columns(table_type="view")
@testing.requires.view_column_reflection
@testing.requires.schemas
def test_get_view_columns_with_schema(self):
self._test_get_columns(
- schema=testing.config.test_schema, table_type='view')
+ schema=testing.config.test_schema, table_type="view"
+ )
@testing.provide_metadata
def _test_get_pk_constraint(self, schema=None):
@@ -504,15 +545,15 @@ class ComponentReflectionTest(fixtures.TablesTest):
insp = inspect(meta.bind)
users_cons = insp.get_pk_constraint(users.name, schema=schema)
- users_pkeys = users_cons['constrained_columns']
- eq_(users_pkeys, ['user_id'])
+ users_pkeys = users_cons["constrained_columns"]
+ eq_(users_pkeys, ["user_id"])
addr_cons = insp.get_pk_constraint(addresses.name, schema=schema)
- addr_pkeys = addr_cons['constrained_columns']
- eq_(addr_pkeys, ['address_id'])
+ addr_pkeys = addr_cons["constrained_columns"]
+ eq_(addr_pkeys, ["address_id"])
with testing.requires.reflects_pk_names.fail_if():
- eq_(addr_cons['name'], 'email_ad_pk')
+ eq_(addr_cons["name"], "email_ad_pk")
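
The dictionary shape being asserted, as a quick sketch against SQLite (the constraint name may come back as None on backends that do not reflect PK names, which is what the fail_if() guard above allows for):

    from sqlalchemy import create_engine, inspect

    engine = create_engine("sqlite://")
    engine.execute("CREATE TABLE users (user_id INTEGER PRIMARY KEY)")
    print(inspect(engine).get_pk_constraint("users"))
    # roughly: {'constrained_columns': ['user_id'], 'name': None}
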
@testing.requires.primary_key_constraint_reflection
def test_get_pk_constraint(self):
@@ -534,44 +575,46 @@ class ComponentReflectionTest(fixtures.TablesTest):
sa_exc.SADeprecationWarning,
"Call to deprecated method get_primary_keys."
" Use get_pk_constraint instead.",
- insp.get_primary_keys, users.name
+ insp.get_primary_keys,
+ users.name,
)
@testing.provide_metadata
def _test_get_foreign_keys(self, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
insp = inspect(meta.bind)
expected_schema = schema
# users
if testing.requires.self_referential_foreign_keys.enabled:
- users_fkeys = insp.get_foreign_keys(users.name,
- schema=schema)
+ users_fkeys = insp.get_foreign_keys(users.name, schema=schema)
fkey1 = users_fkeys[0]
with testing.requires.named_constraints.fail_if():
- eq_(fkey1['name'], "user_id_fk")
+ eq_(fkey1["name"], "user_id_fk")
- eq_(fkey1['referred_schema'], expected_schema)
- eq_(fkey1['referred_table'], users.name)
- eq_(fkey1['referred_columns'], ['user_id', ])
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
if testing.requires.self_referential_foreign_keys.enabled:
- eq_(fkey1['constrained_columns'], ['parent_user_id'])
+ eq_(fkey1["constrained_columns"], ["parent_user_id"])
# addresses
- addr_fkeys = insp.get_foreign_keys(addresses.name,
- schema=schema)
+ addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema)
fkey1 = addr_fkeys[0]
with testing.requires.implicitly_named_constraints.fail_if():
- self.assert_(fkey1['name'] is not None)
+ self.assert_(fkey1["name"] is not None)
- eq_(fkey1['referred_schema'], expected_schema)
- eq_(fkey1['referred_table'], users.name)
- eq_(fkey1['referred_columns'], ['user_id', ])
- eq_(fkey1['constrained_columns'], ['remote_user_id'])
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
+ eq_(fkey1["constrained_columns"], ["remote_user_id"])
@testing.requires.foreign_key_constraint_reflection
def test_get_foreign_keys(self):
@@ -586,9 +629,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.schemas
def test_get_inter_schema_foreign_keys(self):
local_table, remote_table, remote_table_2 = self.tables(
- '%s.local_table' % testing.db.dialect.default_schema_name,
- '%s.remote_table' % testing.config.test_schema,
- '%s.remote_table_2' % testing.config.test_schema
+ "%s.local_table" % testing.db.dialect.default_schema_name,
+ "%s.remote_table" % testing.config.test_schema,
+ "%s.remote_table_2" % testing.config.test_schema,
)
insp = inspect(config.db)
@@ -597,25 +640,25 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(len(local_fkeys), 1)
fkey1 = local_fkeys[0]
- eq_(fkey1['referred_schema'], testing.config.test_schema)
- eq_(fkey1['referred_table'], remote_table_2.name)
- eq_(fkey1['referred_columns'], ['id', ])
- eq_(fkey1['constrained_columns'], ['remote_id'])
+ eq_(fkey1["referred_schema"], testing.config.test_schema)
+ eq_(fkey1["referred_table"], remote_table_2.name)
+ eq_(fkey1["referred_columns"], ["id"])
+ eq_(fkey1["constrained_columns"], ["remote_id"])
remote_fkeys = insp.get_foreign_keys(
- remote_table.name, schema=testing.config.test_schema)
+ remote_table.name, schema=testing.config.test_schema
+ )
eq_(len(remote_fkeys), 1)
fkey2 = remote_fkeys[0]
- assert fkey2['referred_schema'] in (
+ assert fkey2["referred_schema"] in (
None,
- testing.db.dialect.default_schema_name
+ testing.db.dialect.default_schema_name,
)
- eq_(fkey2['referred_table'], local_table.name)
- eq_(fkey2['referred_columns'], ['id', ])
- eq_(fkey2['constrained_columns'], ['local_id'])
-
+ eq_(fkey2["referred_table"], local_table.name)
+ eq_(fkey2["referred_columns"], ["id"])
+ eq_(fkey2["constrained_columns"], ["local_id"])
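
For orientation, a sketch of the foreign key reflection shape these assertions walk, assuming SQLite and illustrative table names:

    from sqlalchemy import create_engine, inspect

    engine = create_engine("sqlite://")
    engine.execute("CREATE TABLE a (id INTEGER PRIMARY KEY)")
    engine.execute(
        "CREATE TABLE b (id INTEGER PRIMARY KEY, "
        "a_id INTEGER REFERENCES a(id))"
    )

    fk = inspect(engine).get_foreign_keys("b")[0]
    # keys checked above: name, referred_schema, referred_table,
    # referred_columns, constrained_columns (plus an 'options' dict)
    print(fk["referred_table"], fk["constrained_columns"])  # a ['a_id']
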
@testing.requires.foreign_key_constraint_option_reflection_ondelete
def test_get_foreign_key_options_ondelete(self):
@@ -630,26 +673,32 @@ class ComponentReflectionTest(fixtures.TablesTest):
meta = self.metadata
Table(
- 'x', meta,
- Column('id', Integer, primary_key=True),
- test_needs_fk=True
- )
-
- Table('table', meta,
- Column('id', Integer, primary_key=True),
- Column('x_id', Integer, sa.ForeignKey('x.id', name='xid')),
- Column('test', String(10)),
- test_needs_fk=True)
-
- Table('user', meta,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('tid', Integer),
- sa.ForeignKeyConstraint(
- ['tid'], ['table.id'],
- name='myfk',
- **options),
- test_needs_fk=True)
+ "x",
+ meta,
+ Column("id", Integer, primary_key=True),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "table",
+ meta,
+ Column("id", Integer, primary_key=True),
+ Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")),
+ Column("test", String(10)),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "user",
+ meta,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("tid", Integer),
+ sa.ForeignKeyConstraint(
+ ["tid"], ["table.id"], name="myfk", **options
+ ),
+ test_needs_fk=True,
+ )
meta.create_all()
@@ -657,49 +706,44 @@ class ComponentReflectionTest(fixtures.TablesTest):
# test 'options' is always present for a backend
# that can reflect these, since alembic looks for this
- opts = insp.get_foreign_keys('table')[0]['options']
+ opts = insp.get_foreign_keys("table")[0]["options"]
- eq_(
- dict(
- (k, opts[k])
- for k in opts if opts[k]
- ),
- {}
- )
+ eq_(dict((k, opts[k]) for k in opts if opts[k]), {})
- opts = insp.get_foreign_keys('user')[0]['options']
- eq_(
- dict(
- (k, opts[k])
- for k in opts if opts[k]
- ),
- options
- )
+ opts = insp.get_foreign_keys("user")[0]["options"]
+ eq_(dict((k, opts[k]) for k in opts if opts[k]), options)
def _assert_insp_indexes(self, indexes, expected_indexes):
- index_names = [d['name'] for d in indexes]
+ index_names = [d["name"] for d in indexes]
for e_index in expected_indexes:
- assert e_index['name'] in index_names
- index = indexes[index_names.index(e_index['name'])]
+ assert e_index["name"] in index_names
+ index = indexes[index_names.index(e_index["name"])]
for key in e_index:
eq_(e_index[key], index[key])
@testing.provide_metadata
def _test_get_indexes(self, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
# The database may decide to create indexes for foreign keys, etc.
# so there may be more indexes than expected.
insp = inspect(meta.bind)
- indexes = insp.get_indexes('users', schema=schema)
+ indexes = insp.get_indexes("users", schema=schema)
expected_indexes = [
- {'unique': False,
- 'column_names': ['test1', 'test2'],
- 'name': 'users_t_idx'},
- {'unique': False,
- 'column_names': ['user_id', 'test2', 'test1'],
- 'name': 'users_all_idx'}
+ {
+ "unique": False,
+ "column_names": ["test1", "test2"],
+ "name": "users_t_idx",
+ },
+ {
+ "unique": False,
+ "column_names": ["user_id", "test2", "test1"],
+ "name": "users_all_idx",
+ },
]
self._assert_insp_indexes(indexes, expected_indexes)
@@ -721,10 +765,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
# reflecting an index that has "x DESC" in it as the column.
# the DB may or may not give us "x", but make sure we get the index
# back, it has a name, it's connected to the table.
- expected_indexes = [
- {'unique': False,
- 'name': ixname}
- ]
+ expected_indexes = [{"unique": False, "name": ixname}]
self._assert_insp_indexes(indexes, expected_indexes)
t = Table(tname, meta, autoload_with=meta.bind)
@@ -748,24 +789,30 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.unique_constraint_reflection
def test_get_temp_table_unique_constraints(self):
insp = inspect(self.bind)
- reflected = insp.get_unique_constraints('user_tmp')
+ reflected = insp.get_unique_constraints("user_tmp")
for refl in reflected:
# Different dialects handle duplicate index and constraints
# differently, so ignore this flag
- refl.pop('duplicates_index', None)
- eq_(reflected, [{'column_names': ['name'], 'name': 'user_tmp_uq'}])
+ refl.pop("duplicates_index", None)
+ eq_(reflected, [{"column_names": ["name"], "name": "user_tmp_uq"}])
@testing.requires.temp_table_reflection
def test_get_temp_table_indexes(self):
insp = inspect(self.bind)
- indexes = insp.get_indexes('user_tmp')
+ indexes = insp.get_indexes("user_tmp")
for ind in indexes:
- ind.pop('dialect_options', None)
+ ind.pop("dialect_options", None)
eq_(
# TODO: we need to add better filtering for indexes/uq constraints
# that are doubled up
- [idx for idx in indexes if idx['name'] == 'user_tmp_ix'],
- [{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}]
+ [idx for idx in indexes if idx["name"] == "user_tmp_ix"],
+ [
+ {
+ "unique": False,
+ "column_names": ["foo"],
+ "name": "user_tmp_ix",
+ }
+ ],
)
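
A sketch of get_indexes(), assuming SQLite. Some dialects report a unique constraint a second time as an index and flag it with a 'duplicates_index' key, which is why the tests above pop that key (or filter by name) before comparing:

    from sqlalchemy import create_engine, inspect

    engine = create_engine("sqlite://")
    engine.execute("CREATE TABLE t (foo INTEGER)")
    engine.execute("CREATE INDEX t_ix ON t (foo)")
    print(inspect(engine).get_indexes("t"))
    # roughly: [{'name': 't_ix', 'column_names': ['foo'], 'unique': ...}]
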
@testing.requires.unique_constraint_reflection
@@ -783,36 +830,37 @@ class ComponentReflectionTest(fixtures.TablesTest):
# CREATE TABLE?
uniques = sorted(
[
- {'name': 'unique_a', 'column_names': ['a']},
- {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']},
- {'name': 'unique_c_a_b', 'column_names': ['c', 'a', 'b']},
- {'name': 'unique_asc_key', 'column_names': ['asc', 'key']},
- {'name': 'i.have.dots', 'column_names': ['b']},
- {'name': 'i have spaces', 'column_names': ['c']},
+ {"name": "unique_a", "column_names": ["a"]},
+ {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]},
+ {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]},
+ {"name": "unique_asc_key", "column_names": ["asc", "key"]},
+ {"name": "i.have.dots", "column_names": ["b"]},
+ {"name": "i have spaces", "column_names": ["c"]},
],
- key=operator.itemgetter('name')
+ key=operator.itemgetter("name"),
)
orig_meta = self.metadata
table = Table(
- 'testtbl', orig_meta,
- Column('a', sa.String(20)),
- Column('b', sa.String(30)),
- Column('c', sa.Integer),
+ "testtbl",
+ orig_meta,
+ Column("a", sa.String(20)),
+ Column("b", sa.String(30)),
+ Column("c", sa.Integer),
# reserved identifiers
- Column('asc', sa.String(30)),
- Column('key', sa.String(30)),
- schema=schema
+ Column("asc", sa.String(30)),
+ Column("key", sa.String(30)),
+ schema=schema,
)
for uc in uniques:
table.append_constraint(
- sa.UniqueConstraint(*uc['column_names'], name=uc['name'])
+ sa.UniqueConstraint(*uc["column_names"], name=uc["name"])
)
orig_meta.create_all()
inspector = inspect(orig_meta.bind)
reflected = sorted(
- inspector.get_unique_constraints('testtbl', schema=schema),
- key=operator.itemgetter('name')
+ inspector.get_unique_constraints("testtbl", schema=schema),
+ key=operator.itemgetter("name"),
)
names_that_duplicate_index = set()
@@ -820,25 +868,31 @@ class ComponentReflectionTest(fixtures.TablesTest):
for orig, refl in zip(uniques, reflected):
# Different dialects handle duplicate index and constraints
# differently, so ignore this flag
- dupe = refl.pop('duplicates_index', None)
+ dupe = refl.pop("duplicates_index", None)
if dupe:
names_that_duplicate_index.add(dupe)
eq_(orig, refl)
reflected_metadata = MetaData()
reflected = Table(
- 'testtbl', reflected_metadata, autoload_with=orig_meta.bind,
- schema=schema)
+ "testtbl",
+ reflected_metadata,
+ autoload_with=orig_meta.bind,
+ schema=schema,
+ )
# test "deduplicates for index" logic. MySQL and Oracle
# "unique constraints" are actually unique indexes (with possible
# exception of a unique that is a dupe of another one in the case
        # of Oracle). make sure they aren't duplicated.
idx_names = set([idx.name for idx in reflected.indexes])
- uq_names = set([
- uq.name for uq in reflected.constraints
- if isinstance(uq, sa.UniqueConstraint)]).difference(
- ['unique_c_a_b'])
+ uq_names = set(
+ [
+ uq.name
+ for uq in reflected.constraints
+ if isinstance(uq, sa.UniqueConstraint)
+ ]
+ ).difference(["unique_c_a_b"])
assert not idx_names.intersection(uq_names)
if names_that_duplicate_index:
@@ -858,47 +912,52 @@ class ComponentReflectionTest(fixtures.TablesTest):
def _test_get_check_constraints(self, schema=None):
orig_meta = self.metadata
Table(
- 'sa_cc', orig_meta,
- Column('a', Integer()),
- sa.CheckConstraint('a > 1 AND a < 5', name='cc1'),
- sa.CheckConstraint('a = 1 OR (a > 2 AND a < 5)', name='cc2'),
- schema=schema
+ "sa_cc",
+ orig_meta,
+ Column("a", Integer()),
+ sa.CheckConstraint("a > 1 AND a < 5", name="cc1"),
+ sa.CheckConstraint("a = 1 OR (a > 2 AND a < 5)", name="cc2"),
+ schema=schema,
)
orig_meta.create_all()
inspector = inspect(orig_meta.bind)
reflected = sorted(
- inspector.get_check_constraints('sa_cc', schema=schema),
- key=operator.itemgetter('name')
+ inspector.get_check_constraints("sa_cc", schema=schema),
+ key=operator.itemgetter("name"),
)
        # trying to minimize the effect of quoting, parentheses, etc.
# may need to add more to this as new dialects get CHECK
# constraint reflection support
def normalize(sqltext):
- return " ".join(re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I))
+ return " ".join(
+ re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I)
+ )
reflected = [
- {"name": item["name"],
- "sqltext": normalize(item["sqltext"])}
+ {"name": item["name"], "sqltext": normalize(item["sqltext"])}
for item in reflected
]
eq_(
reflected,
[
- {'name': 'cc1', 'sqltext': 'a > 1 and a < 5'},
- {'name': 'cc2', 'sqltext': 'a = 1 or a > 2 and a < 5'}
- ]
+ {"name": "cc1", "sqltext": "a > 1 and a < 5"},
+ {"name": "cc2", "sqltext": "a = 1 or a > 2 and a < 5"},
+ ],
)
@testing.provide_metadata
def _test_get_view_definition(self, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
- view_name1 = 'users_v'
- view_name2 = 'email_addresses_v'
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
+ view_name1 = "users_v"
+ view_name2 = "email_addresses_v"
insp = inspect(meta.bind)
v1 = insp.get_view_definition(view_name1, schema=schema)
self.assert_(v1)
@@ -918,18 +977,21 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.provide_metadata
def _test_get_table_oid(self, table_name, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
insp = inspect(meta.bind)
oid = insp.get_table_oid(table_name, schema)
self.assert_(isinstance(oid, int))
def test_get_table_oid(self):
- self._test_get_table_oid('users')
+ self._test_get_table_oid("users")
@testing.requires.schemas
def test_get_table_oid_with_schema(self):
- self._test_get_table_oid('users', schema=testing.config.test_schema)
+ self._test_get_table_oid("users", schema=testing.config.test_schema)
@testing.requires.table_reflection
@testing.provide_metadata
@@ -950,49 +1012,53 @@ class ComponentReflectionTest(fixtures.TablesTest):
insp = inspect(meta.bind)
for tname, cname in [
- ('users', 'user_id'),
- ('email_addresses', 'address_id'),
- ('dingalings', 'dingaling_id'),
+ ("users", "user_id"),
+ ("email_addresses", "address_id"),
+ ("dingalings", "dingaling_id"),
]:
cols = insp.get_columns(tname)
- id_ = {c['name']: c for c in cols}[cname]
- assert id_.get('autoincrement', True)
+ id_ = {c["name"]: c for c in cols}[cname]
+ assert id_.get("autoincrement", True)
class NormalizedNameTest(fixtures.TablesTest):
- __requires__ = 'denormalized_names',
+ __requires__ = ("denormalized_names",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
- quoted_name('t1', quote=True), metadata,
- Column('id', Integer, primary_key=True),
+ quoted_name("t1", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
)
Table(
- quoted_name('t2', quote=True), metadata,
- Column('id', Integer, primary_key=True),
- Column('t1id', ForeignKey('t1.id'))
+ quoted_name("t2", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("t1id", ForeignKey("t1.id")),
)
def test_reflect_lowercase_forced_tables(self):
m2 = MetaData(testing.db)
- t2_ref = Table(quoted_name('t2', quote=True), m2, autoload=True)
- t1_ref = m2.tables['t1']
+ t2_ref = Table(quoted_name("t2", quote=True), m2, autoload=True)
+ t1_ref = m2.tables["t1"]
assert t2_ref.c.t1id.references(t1_ref.c.id)
m3 = MetaData(testing.db)
- m3.reflect(only=lambda name, m: name.lower() in ('t1', 't2'))
- assert m3.tables['t2'].c.t1id.references(m3.tables['t1'].c.id)
+ m3.reflect(only=lambda name, m: name.lower() in ("t1", "t2"))
+ assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id)
def test_get_table_names(self):
tablenames = [
- t for t in inspect(testing.db).get_table_names()
- if t.lower() in ("t1", "t2")]
+ t
+ for t in inspect(testing.db).get_table_names()
+ if t.lower() in ("t1", "t2")
+ ]
eq_(tablenames[0].upper(), tablenames[0].lower())
eq_(tablenames[1].upper(), tablenames[1].lower())
-__all__ = ('ComponentReflectionTest', 'HasTableTest', 'NormalizedNameTest')
+__all__ = ("ComponentReflectionTest", "HasTableTest", "NormalizedNameTest")
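
A short sketch of what quoted_name(..., quote=True) does in NormalizedNameTest, no database required; on a denormalizing backend such as Oracle, an unquoted lowercase name would otherwise be folded to the dialect's uppercase convention:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.sql.elements import quoted_name

    meta = MetaData()
    Table(
        quoted_name("t1", quote=True),  # force a quoted, case-exact "t1"
        meta,
        Column("id", Integer, primary_key=True),
    )
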
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
index f464d47eb..247f05cf5 100644
--- a/lib/sqlalchemy/testing/suite/test_results.py
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -15,14 +15,18 @@ class RowFetchTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table('plain_pk', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50))
- )
- Table('has_dates', metadata,
- Column('id', Integer, primary_key=True),
- Column('today', DateTime)
- )
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ Table(
+ "has_dates",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("today", DateTime),
+ )
@classmethod
def insert_data(cls):
@@ -32,65 +36,51 @@ class RowFetchTest(fixtures.TablesTest):
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
- ]
+ ],
)
config.db.execute(
cls.tables.has_dates.insert(),
- [
- {"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}
- ]
+ [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
)
def test_via_string(self):
row = config.db.execute(
- self.tables.plain_pk.select().
- order_by(self.tables.plain_pk.c.id)
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
- eq_(
- row['id'], 1
- )
- eq_(
- row['data'], "d1"
- )
+ eq_(row["id"], 1)
+ eq_(row["data"], "d1")
def test_via_int(self):
row = config.db.execute(
- self.tables.plain_pk.select().
- order_by(self.tables.plain_pk.c.id)
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
- eq_(
- row[0], 1
- )
- eq_(
- row[1], "d1"
- )
+ eq_(row[0], 1)
+ eq_(row[1], "d1")
def test_via_col_object(self):
row = config.db.execute(
- self.tables.plain_pk.select().
- order_by(self.tables.plain_pk.c.id)
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
- eq_(
- row[self.tables.plain_pk.c.id], 1
- )
- eq_(
- row[self.tables.plain_pk.c.data], "d1"
- )
+ eq_(row[self.tables.plain_pk.c.id], 1)
+ eq_(row[self.tables.plain_pk.c.data], "d1")
@requirements.duplicate_names_in_cursor_description
def test_row_with_dupe_names(self):
result = config.db.execute(
- select([self.tables.plain_pk.c.data,
- self.tables.plain_pk.c.data.label('data')]).
- order_by(self.tables.plain_pk.c.id)
+ select(
+ [
+ self.tables.plain_pk.c.data,
+ self.tables.plain_pk.c.data.label("data"),
+ ]
+ ).order_by(self.tables.plain_pk.c.id)
)
row = result.first()
- eq_(result.keys(), ['data', 'data'])
- eq_(row, ('d1', 'd1'))
+ eq_(result.keys(), ["data", "data"])
+ eq_(row, ("d1", "d1"))
def test_row_w_scalar_select(self):
"""test that a scalar select as a column is returned as such
@@ -101,11 +91,11 @@ class RowFetchTest(fixtures.TablesTest):
"""
datetable = self.tables.has_dates
- s = select([datetable.alias('x').c.today]).as_scalar()
- s2 = select([datetable.c.id, s.label('somelabel')])
+ s = select([datetable.alias("x").c.today]).as_scalar()
+ s2 = select([datetable.c.id, s.label("somelabel")])
row = config.db.execute(s2).first()
- eq_(row['somelabel'], datetime.datetime(2006, 5, 12, 12, 0, 0))
+ eq_(row["somelabel"], datetime.datetime(2006, 5, 12, 12, 0, 0))
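
A runnable sketch of the scalar-select behavior the docstring describes, assuming SQLite: as_scalar() turns the subquery into a column expression, so the fetched row holds its value rather than a nested row object:

    import datetime

    from sqlalchemy import (
        Column, DateTime, Integer, MetaData, Table, create_engine, select,
    )

    engine = create_engine("sqlite://")
    meta = MetaData()
    t = Table(
        "has_dates",
        meta,
        Column("id", Integer, primary_key=True),
        Column("today", DateTime),
    )
    meta.create_all(engine)
    engine.execute(
        t.insert(), {"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0)}
    )

    s = select([t.alias("x").c.today]).as_scalar()
    row = engine.execute(select([t.c.id, s.label("somelabel")])).first()
    print(row["somelabel"])  # 2006-05-12 12:00:00
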
class PercentSchemaNamesTest(fixtures.TablesTest):
@@ -117,29 +107,31 @@ class PercentSchemaNamesTest(fixtures.TablesTest):
"""
- __requires__ = ('percent_schema_names', )
+ __requires__ = ("percent_schema_names",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- cls.tables.percent_table = Table('percent%table', metadata,
- Column("percent%", Integer),
- Column(
- "spaces % more spaces", Integer),
- )
+ cls.tables.percent_table = Table(
+ "percent%table",
+ metadata,
+ Column("percent%", Integer),
+ Column("spaces % more spaces", Integer),
+ )
cls.tables.lightweight_percent_table = sql.table(
- 'percent%table', sql.column("percent%"),
- sql.column("spaces % more spaces")
+ "percent%table",
+ sql.column("percent%"),
+ sql.column("spaces % more spaces"),
)
def test_single_roundtrip(self):
percent_table = self.tables.percent_table
for params in [
- {'percent%': 5, 'spaces % more spaces': 12},
- {'percent%': 7, 'spaces % more spaces': 11},
- {'percent%': 9, 'spaces % more spaces': 10},
- {'percent%': 11, 'spaces % more spaces': 9}
+ {"percent%": 5, "spaces % more spaces": 12},
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
]:
config.db.execute(percent_table.insert(), params)
self._assert_table()
@@ -147,14 +139,15 @@ class PercentSchemaNamesTest(fixtures.TablesTest):
def test_executemany_roundtrip(self):
percent_table = self.tables.percent_table
config.db.execute(
- percent_table.insert(),
- {'percent%': 5, 'spaces % more spaces': 12}
+ percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
)
config.db.execute(
percent_table.insert(),
- [{'percent%': 7, 'spaces % more spaces': 11},
- {'percent%': 9, 'spaces % more spaces': 10},
- {'percent%': 11, 'spaces % more spaces': 9}]
+ [
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
+ ],
)
self._assert_table()
@@ -163,85 +156,81 @@ class PercentSchemaNamesTest(fixtures.TablesTest):
lightweight_percent_table = self.tables.lightweight_percent_table
for table in (
- percent_table,
- percent_table.alias(),
- lightweight_percent_table,
- lightweight_percent_table.alias()):
+ percent_table,
+ percent_table.alias(),
+ lightweight_percent_table,
+ lightweight_percent_table.alias(),
+ ):
eq_(
list(
config.db.execute(
- table.select().order_by(table.c['percent%'])
+ table.select().order_by(table.c["percent%"])
)
),
- [
- (5, 12),
- (7, 11),
- (9, 10),
- (11, 9)
- ]
+ [(5, 12), (7, 11), (9, 10), (11, 9)],
)
eq_(
list(
config.db.execute(
- table.select().
- where(table.c['spaces % more spaces'].in_([9, 10])).
- order_by(table.c['percent%']),
+ table.select()
+ .where(table.c["spaces % more spaces"].in_([9, 10]))
+ .order_by(table.c["percent%"])
)
),
- [
- (9, 10),
- (11, 9)
- ]
+ [(9, 10), (11, 9)],
)
- row = config.db.execute(table.select().
- order_by(table.c['percent%'])).first()
- eq_(row['percent%'], 5)
- eq_(row['spaces % more spaces'], 12)
+ row = config.db.execute(
+ table.select().order_by(table.c["percent%"])
+ ).first()
+ eq_(row["percent%"], 5)
+ eq_(row["spaces % more spaces"], 12)
- eq_(row[table.c['percent%']], 5)
- eq_(row[table.c['spaces % more spaces']], 12)
+ eq_(row[table.c["percent%"]], 5)
+ eq_(row[table.c["spaces % more spaces"]], 12)
config.db.execute(
percent_table.update().values(
- {percent_table.c['spaces % more spaces']: 15}
+ {percent_table.c["spaces % more spaces"]: 15}
)
)
eq_(
list(
config.db.execute(
- percent_table.
- select().
- order_by(percent_table.c['percent%'])
+ percent_table.select().order_by(
+ percent_table.c["percent%"]
+ )
)
),
- [(5, 15), (7, 15), (9, 15), (11, 15)]
+ [(5, 15), (7, 15), (9, 15), (11, 15)],
)
-class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
+class ServerSideCursorsTest(
+ fixtures.TestBase, testing.AssertsExecutionResults
+):
- __requires__ = ('server_side_cursors', )
+ __requires__ = ("server_side_cursors",)
__backend__ = True
def _is_server_side(self, cursor):
if self.engine.dialect.driver == "psycopg2":
return cursor.name
- elif self.engine.dialect.driver == 'pymysql':
- sscursor = __import__('pymysql.cursors').cursors.SSCursor
+ elif self.engine.dialect.driver == "pymysql":
+ sscursor = __import__("pymysql.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
elif self.engine.dialect.driver == "mysqldb":
- sscursor = __import__('MySQLdb.cursors').cursors.SSCursor
+ sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
else:
return False
def _fixture(self, server_side_cursors):
self.engine = engines.testing_engine(
- options={'server_side_cursors': server_side_cursors}
+ options={"server_side_cursors": server_side_cursors}
)
return self.engine
@@ -251,12 +240,12 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
def test_global_string(self):
engine = self._fixture(True)
- result = engine.execute('select 1')
+ result = engine.execute("select 1")
assert self._is_server_side(result.cursor)
def test_global_text(self):
engine = self._fixture(True)
- result = engine.execute(text('select 1'))
+ result = engine.execute(text("select 1"))
assert self._is_server_side(result.cursor)
def test_global_expr(self):
@@ -266,7 +255,7 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
def test_global_off_explicit(self):
engine = self._fixture(False)
- result = engine.execute(text('select 1'))
+ result = engine.execute(text("select 1"))
# It should be off globally ...
@@ -286,10 +275,11 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
engine = self._fixture(False)
# and this one
- result = \
- engine.connect().execution_options(stream_results=True).\
- execute('select 1'
- )
+ result = (
+ engine.connect()
+ .execution_options(stream_results=True)
+ .execute("select 1")
+ )
assert self._is_server_side(result.cursor)
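
A sketch of the option hierarchy these tests walk, assuming a psycopg2 URL (illustrative): statement- and connection-level stream_results override the engine-wide server_side_cursors flag in both directions:

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql://scott:tiger@localhost/test")

    # per-connection: server-side even though the engine default is off
    result = (
        engine.connect()
        .execution_options(stream_results=True)
        .execute(text("select 1"))
    )

    # per-statement: the option travels with the statement object itself
    stmt = text("select 42").execution_options(stream_results=True)
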
def test_stmt_enabled_conn_option_disabled(self):
@@ -298,9 +288,9 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
s = select([1]).execution_options(stream_results=True)
# not this one
- result = \
- engine.connect().execution_options(stream_results=False).\
- execute(s)
+ result = (
+ engine.connect().execution_options(stream_results=False).execute(s)
+ )
assert not self._is_server_side(result.cursor)
def test_stmt_option_disabled(self):
@@ -329,18 +319,18 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
def test_for_update_string(self):
engine = self._fixture(True)
- result = engine.execute('SELECT 1 FOR UPDATE')
+ result = engine.execute("SELECT 1 FOR UPDATE")
assert self._is_server_side(result.cursor)
def test_text_no_ss(self):
engine = self._fixture(False)
- s = text('select 42')
+ s = text("select 42")
result = engine.execute(s)
assert not self._is_server_side(result.cursor)
def test_text_ss_option(self):
engine = self._fixture(False)
- s = text('select 42').execution_options(stream_results=True)
+ s = text("select 42").execution_options(stream_results=True)
result = engine.execute(s)
assert self._is_server_side(result.cursor)
@@ -349,19 +339,25 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
md = self.metadata
engine = self._fixture(True)
- test_table = Table('test_table', md,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)))
+ test_table = Table(
+ "test_table",
+ md,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
test_table.create(checkfirst=True)
- test_table.insert().execute(data='data1')
- test_table.insert().execute(data='data2')
- eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(),
- [(1, 'data1'), (2, 'data2')])
- test_table.update().where(
- test_table.c.id == 2).values(
- data=test_table.c.data +
- ' updated').execute()
- eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(),
- [(1, 'data1'), (2, 'data2 updated')])
+ test_table.insert().execute(data="data1")
+ test_table.insert().execute(data="data2")
+ eq_(
+ test_table.select().order_by(test_table.c.id).execute().fetchall(),
+ [(1, "data1"), (2, "data2")],
+ )
+ test_table.update().where(test_table.c.id == 2).values(
+ data=test_table.c.data + " updated"
+ ).execute()
+ eq_(
+ test_table.select().order_by(test_table.c.id).execute().fetchall(),
+ [(1, "data1"), (2, "data2 updated")],
+ )
test_table.delete().execute()
- eq_(select([func.count('*')]).select_from(test_table).scalar(), 0)
+ eq_(select([func.count("*")]).select_from(test_table).scalar(), 0)
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
index 73ce02492..032b68eb6 100644
--- a/lib/sqlalchemy/testing/suite/test_select.py
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -16,10 +16,12 @@ class CollateTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(100))
- )
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(100)),
+ )
@classmethod
def insert_data(cls):
@@ -28,26 +30,21 @@ class CollateTest(fixtures.TablesTest):
[
{"id": 1, "data": "collate data1"},
{"id": 2, "data": "collate data2"},
- ]
+ ],
)
def _assert_result(self, select, result):
- eq_(
- config.db.execute(select).fetchall(),
- result
- )
+ eq_(config.db.execute(select).fetchall(), result)
@testing.requires.order_by_collation
def test_collate_order_by(self):
collation = testing.requires.get_order_by_collation(testing.config)
self._assert_result(
- select([self.tables.some_table]).
- order_by(self.tables.some_table.c.data.collate(collation).asc()),
- [
- (1, "collate data1"),
- (2, "collate data2"),
- ]
+ select([self.tables.some_table]).order_by(
+ self.tables.some_table.c.data.collate(collation).asc()
+ ),
+ [(1, "collate data1"), (2, "collate data2")],
)
@@ -59,17 +56,20 @@ class OrderByLabelTest(fixtures.TablesTest):
setting.
"""
+
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer),
- Column('q', String(50)),
- Column('p', String(50))
- )
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("q", String(50)),
+ Column("p", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -79,65 +79,55 @@ class OrderByLabelTest(fixtures.TablesTest):
{"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"},
{"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"},
{"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"},
- ]
+ ],
)
def _assert_result(self, select, result):
- eq_(
- config.db.execute(select).fetchall(),
- result
- )
+ eq_(config.db.execute(select).fetchall(), result)
def test_plain(self):
table = self.tables.some_table
- lx = table.c.x.label('lx')
- self._assert_result(
- select([lx]).order_by(lx),
- [(1, ), (2, ), (3, )]
- )
+ lx = table.c.x.label("lx")
+ self._assert_result(select([lx]).order_by(lx), [(1,), (2,), (3,)])
def test_composed_int(self):
table = self.tables.some_table
- lx = (table.c.x + table.c.y).label('lx')
- self._assert_result(
- select([lx]).order_by(lx),
- [(3, ), (5, ), (7, )]
- )
+ lx = (table.c.x + table.c.y).label("lx")
+ self._assert_result(select([lx]).order_by(lx), [(3,), (5,), (7,)])
def test_composed_multiple(self):
table = self.tables.some_table
- lx = (table.c.x + table.c.y).label('lx')
- ly = (func.lower(table.c.q) + table.c.p).label('ly')
+ lx = (table.c.x + table.c.y).label("lx")
+ ly = (func.lower(table.c.q) + table.c.p).label("ly")
self._assert_result(
select([lx, ly]).order_by(lx, ly.desc()),
- [(3, util.u('q1p3')), (5, util.u('q2p2')), (7, util.u('q3p1'))]
+ [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))],
)
def test_plain_desc(self):
table = self.tables.some_table
- lx = table.c.x.label('lx')
+ lx = table.c.x.label("lx")
self._assert_result(
- select([lx]).order_by(lx.desc()),
- [(3, ), (2, ), (1, )]
+ select([lx]).order_by(lx.desc()), [(3,), (2,), (1,)]
)
def test_composed_int_desc(self):
table = self.tables.some_table
- lx = (table.c.x + table.c.y).label('lx')
+ lx = (table.c.x + table.c.y).label("lx")
self._assert_result(
- select([lx]).order_by(lx.desc()),
- [(7, ), (5, ), (3, )]
+ select([lx]).order_by(lx.desc()), [(7,), (5,), (3,)]
)
@testing.requires.group_by_complex_expression
def test_group_by_composed(self):
table = self.tables.some_table
- expr = (table.c.x + table.c.y).label('lx')
- stmt = select([func.count(table.c.id), expr]).group_by(expr).order_by(expr)
- self._assert_result(
- stmt,
- [(1, 3), (1, 5), (1, 7)]
+ expr = (table.c.x + table.c.y).label("lx")
+ stmt = (
+ select([func.count(table.c.id), expr])
+ .group_by(expr)
+ .order_by(expr)
)
+ self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)])
class LimitOffsetTest(fixtures.TablesTest):
@@ -145,10 +135,13 @@ class LimitOffsetTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
@classmethod
def insert_data(cls):
@@ -159,20 +152,17 @@ class LimitOffsetTest(fixtures.TablesTest):
{"id": 2, "x": 2, "y": 3},
{"id": 3, "x": 3, "y": 4},
{"id": 4, "x": 4, "y": 5},
- ]
+ ],
)
def _assert_result(self, select, result, params=()):
- eq_(
- config.db.execute(select, params).fetchall(),
- result
- )
+ eq_(config.db.execute(select, params).fetchall(), result)
def test_simple_limit(self):
table = self.tables.some_table
self._assert_result(
select([table]).order_by(table.c.id).limit(2),
- [(1, 1, 2), (2, 2, 3)]
+ [(1, 1, 2), (2, 2, 3)],
)
@testing.requires.offset
@@ -180,7 +170,7 @@ class LimitOffsetTest(fixtures.TablesTest):
table = self.tables.some_table
self._assert_result(
select([table]).order_by(table.c.id).offset(2),
- [(3, 3, 4), (4, 4, 5)]
+ [(3, 3, 4), (4, 4, 5)],
)
@testing.requires.offset
@@ -188,7 +178,7 @@ class LimitOffsetTest(fixtures.TablesTest):
table = self.tables.some_table
self._assert_result(
select([table]).order_by(table.c.id).limit(2).offset(1),
- [(2, 2, 3), (3, 3, 4)]
+ [(2, 2, 3), (3, 3, 4)],
)
@testing.requires.offset
@@ -198,41 +188,40 @@ class LimitOffsetTest(fixtures.TablesTest):
table = self.tables.some_table
stmt = select([table]).order_by(table.c.id).limit(2).offset(1)
sql = stmt.compile(
- dialect=config.db.dialect,
- compile_kwargs={"literal_binds": True})
+ dialect=config.db.dialect, compile_kwargs={"literal_binds": True}
+ )
sql = str(sql)
- self._assert_result(
- sql,
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(sql, [(2, 2, 3), (3, 3, 4)])
@testing.requires.bound_limit_offset
def test_bound_limit(self):
table = self.tables.some_table
self._assert_result(
- select([table]).order_by(table.c.id).limit(bindparam('l')),
+ select([table]).order_by(table.c.id).limit(bindparam("l")),
[(1, 1, 2), (2, 2, 3)],
- params={"l": 2}
+ params={"l": 2},
)
@testing.requires.bound_limit_offset
def test_bound_offset(self):
table = self.tables.some_table
self._assert_result(
- select([table]).order_by(table.c.id).offset(bindparam('o')),
+ select([table]).order_by(table.c.id).offset(bindparam("o")),
[(3, 3, 4), (4, 4, 5)],
- params={"o": 2}
+ params={"o": 2},
)
@testing.requires.bound_limit_offset
def test_bound_limit_offset(self):
table = self.tables.some_table
self._assert_result(
- select([table]).order_by(table.c.id).
- limit(bindparam("l")).offset(bindparam("o")),
+ select([table])
+ .order_by(table.c.id)
+ .limit(bindparam("l"))
+ .offset(bindparam("o")),
[(2, 2, 3), (3, 3, 4)],
- params={"l": 2, "o": 1}
+ params={"l": 2, "o": 1},
)
@@ -241,10 +230,13 @@ class CompoundSelectTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
@classmethod
def insert_data(cls):
@@ -255,14 +247,11 @@ class CompoundSelectTest(fixtures.TablesTest):
{"id": 2, "x": 2, "y": 3},
{"id": 3, "x": 3, "y": 4},
{"id": 4, "x": 4, "y": 5},
- ]
+ ],
)
def _assert_result(self, select, result, params=()):
- eq_(
- config.db.execute(select, params).fetchall(),
- result
- )
+ eq_(config.db.execute(select, params).fetchall(), result)
def test_plain_union(self):
table = self.tables.some_table
@@ -270,10 +259,7 @@ class CompoundSelectTest(fixtures.TablesTest):
s2 = select([table]).where(table.c.id == 3)
u1 = union(s1, s2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
def test_select_from_plain_union(self):
table = self.tables.some_table
@@ -281,80 +267,88 @@ class CompoundSelectTest(fixtures.TablesTest):
s2 = select([table]).where(table.c.id == 3)
u1 = union(s1, s2).alias().select()
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
@testing.requires.order_by_col_from_union
@testing.requires.parens_in_union_contained_select_w_limit_offset
def test_limit_offset_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- limit(1).order_by(table.c.id)
- s2 = select([table]).where(table.c.id == 3).\
- limit(1).order_by(table.c.id)
+ s1 = (
+ select([table])
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ )
+ s2 = (
+ select([table])
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ )
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
@testing.requires.parens_in_union_contained_select_wo_limit_offset
def test_order_by_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- order_by(table.c.id)
- s2 = select([table]).where(table.c.id == 3).\
- order_by(table.c.id)
+ s1 = select([table]).where(table.c.id == 2).order_by(table.c.id)
+ s2 = select([table]).where(table.c.id == 3).order_by(table.c.id)
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
def test_distinct_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- distinct()
- s2 = select([table]).where(table.c.id == 3).\
- distinct()
+ s1 = select([table]).where(table.c.id == 2).distinct()
+ s2 = select([table]).where(table.c.id == 3).distinct()
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
@testing.requires.parens_in_union_contained_select_w_limit_offset
def test_limit_offset_in_unions_from_alias(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- limit(1).order_by(table.c.id)
- s2 = select([table]).where(table.c.id == 3).\
- limit(1).order_by(table.c.id)
+ s1 = (
+ select([table])
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ )
+ s2 = (
+ select([table])
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ )
# this necessarily has double parens
u1 = union(s1, s2).alias()
self._assert_result(
- u1.select().limit(2).order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
+ u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
)
def test_limit_offset_aliased_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- limit(1).order_by(table.c.id).alias().select()
- s2 = select([table]).where(table.c.id == 3).\
- limit(1).order_by(table.c.id).alias().select()
+ s1 = (
+ select([table])
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+ s2 = (
+ select([table])
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
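Reusing the engine/some_table sketch above: the distinct variant of these union tests carries no requirements decorator, because without an inner ORDER BY or LIMIT no parenthesized contained SELECTs are rendered, so it runs on most backends. A hedged illustration:

    from sqlalchemy import union

    s1 = select([some_table]).where(some_table.c.id == 2).distinct()
    s2 = select([some_table]).where(some_table.c.id == 3).distinct()

    u1 = union(s1, s2).limit(2)
    print(engine.execute(u1.order_by(u1.c.id)).fetchall())
    # [(2, 2, 3), (3, 3, 4)]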
class ExpandingBoundInTest(fixtures.TablesTest):
@@ -362,11 +356,14 @@ class ExpandingBoundInTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer),
- Column('z', String(50)))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -377,178 +374,184 @@ class ExpandingBoundInTest(fixtures.TablesTest):
{"id": 2, "x": 2, "y": 3, "z": "z2"},
{"id": 3, "x": 3, "y": 4, "z": "z3"},
{"id": 4, "x": 4, "y": 5, "z": "z4"},
- ]
+ ],
)
def _assert_result(self, select, result, params=()):
- eq_(
- config.db.execute(select, params).fetchall(),
- result
- )
+ eq_(config.db.execute(select, params).fetchall(), result)
def test_multiple_empty_sets(self):
# test that any anonymous aliasing used by the dialect
# is fine with duplicates
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.in_(bindparam('q', expanding=True))).where(
- table.c.y.in_(bindparam('p', expanding=True))
- ).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": [], "p": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.in_(bindparam("q", expanding=True)))
+ .where(table.c.y.in_(bindparam("p", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": [], "p": []})
+
@testing.requires.tuple_in
def test_empty_heterogeneous_tuples(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.z).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
@testing.requires.tuple_in
def test_empty_homogeneous_tuples(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.y).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.y).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
def test_bound_in_scalar(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [(2, ), (3, ), (4, )],
- params={"q": [2, 3, 4]},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.in_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]})
+
@testing.requires.tuple_in
def test_bound_in_two_tuple(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.y).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.y).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
+ )
self._assert_result(
- stmt,
- [(2, ), (3, ), (4, )],
- params={"q": [(2, 3), (3, 4), (4, 5)]},
+ stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]}
)
@testing.requires.tuple_in
def test_bound_in_heterogeneous_two_tuple(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.z).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
+ )
self._assert_result(
stmt,
- [(2, ), (3, ), (4, )],
+ [(2,), (3,), (4,)],
params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
)
def test_empty_set_against_integer(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.in_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
def test_empty_set_against_integer_negation(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.notin_(bindparam('q', expanding=True))
- ).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [(1, ), (2, ), (3, ), (4, )],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.notin_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
def test_empty_set_against_string(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.z.in_(bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.z.in_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
def test_empty_set_against_string_negation(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.z.notin_(bindparam('q', expanding=True))
- ).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [(1, ), (2, ), (3, ), (4, )],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.z.notin_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
def test_null_in_empty_set_is_false(self):
- stmt = select([
- case(
- [
- (
- null().in_(bindparam('foo', value=(), expanding=True)),
- true()
- )
- ],
- else_=false()
- )
- ])
- in_(
- config.db.execute(stmt).fetchone()[0],
- (False, 0)
+ stmt = select(
+ [
+ case(
+ [
+ (
+ null().in_(
+ bindparam("foo", value=(), expanding=True)
+ ),
+ true(),
+ )
+ ],
+ else_=false(),
+ )
+ ]
)
+ in_(config.db.execute(stmt).fetchone()[0], (False, 0))
class LikeFunctionsTest(fixtures.TablesTest):
__backend__ = True
- run_inserts = 'once'
+ run_inserts = "once"
run_deletes = None
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -565,7 +568,7 @@ class LikeFunctionsTest(fixtures.TablesTest):
{"id": 8, "data": "ab9cdefg"},
{"id": 9, "data": "abcde#fg"},
{"id": 10, "data": "abcd9fg"},
- ]
+ ],
)
def _test(self, expr, expected):
@@ -573,8 +576,10 @@ class LikeFunctionsTest(fixtures.TablesTest):
with config.db.connect() as conn:
rows = {
- value for value, in
- conn.execute(select([some_table.c.id]).where(expr))
+ value
+ for value, in conn.execute(
+ select([some_table.c.id]).where(expr)
+ )
}
eq_(rows, expected)
@@ -591,7 +596,8 @@ class LikeFunctionsTest(fixtures.TablesTest):
col = self.tables.some_table.c.data
self._test(
col.startswith(literal_column("'ab%c'")),
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
+ )
def test_startswith_escape(self):
col = self.tables.some_table.c.data
@@ -608,8 +614,9 @@ class LikeFunctionsTest(fixtures.TablesTest):
def test_endswith_sqlexpr(self):
col = self.tables.some_table.c.data
- self._test(col.endswith(literal_column("'e%fg'")),
- {1, 2, 3, 4, 5, 6, 7, 8, 9})
+ self._test(
+ col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9}
+ )
def test_endswith_autoescape(self):
col = self.tables.some_table.c.data
@@ -640,5 +647,3 @@ class LikeFunctionsTest(fixtures.TablesTest):
col = self.tables.some_table.c.data
self._test(col.contains("b%cd", autoescape=True, escape="#"), {3})
self._test(col.contains("b#cd", autoescape=True, escape="#"), {7})
-
-
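Before the next file: the expanding-IN construct that dominates the rewritten ExpandingBoundInTest can be sketched with the same engine/some_table setup as the LimitOffsetTest note above. One placeholder expands into an IN list per execution; empty lists are handled on this branch, as the empty-set tests show:

    stmt = (
        select([some_table.c.id])
        .where(some_table.c.x.in_(bindparam("q", expanding=True)))
        .order_by(some_table.c.id)
    )
    print(engine.execute(stmt, {"q": [2, 3, 4]}).fetchall())  # [(2,), (3,), (4,)]
    print(engine.execute(stmt, {"q": []}).fetchall())  # []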
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
index f1c00de6b..15a850fe9 100644
--- a/lib/sqlalchemy/testing/suite/test_sequence.py
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -9,140 +9,144 @@ from ..schema import Table, Column
class SequenceTest(fixtures.TablesTest):
- __requires__ = ('sequences',)
+ __requires__ = ("sequences",)
__backend__ = True
- run_create_tables = 'each'
+ run_create_tables = "each"
@classmethod
def define_tables(cls, metadata):
- Table('seq_pk', metadata,
- Column('id', Integer, Sequence('tab_id_seq'), primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "seq_pk",
+ metadata,
+ Column("id", Integer, Sequence("tab_id_seq"), primary_key=True),
+ Column("data", String(50)),
+ )
- Table('seq_opt_pk', metadata,
- Column('id', Integer, Sequence('tab_id_seq', optional=True),
- primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "seq_opt_pk",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("tab_id_seq", optional=True),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ )
def test_insert_roundtrip(self):
- config.db.execute(
- self.tables.seq_pk.insert(),
- data="some data"
- )
+ config.db.execute(self.tables.seq_pk.insert(), data="some data")
self._assert_round_trip(self.tables.seq_pk, config.db)
def test_insert_lastrowid(self):
- r = config.db.execute(
- self.tables.seq_pk.insert(),
- data="some data"
- )
- eq_(
- r.inserted_primary_key,
- [1]
- )
+ r = config.db.execute(self.tables.seq_pk.insert(), data="some data")
+ eq_(r.inserted_primary_key, [1])
def test_nextval_direct(self):
- r = config.db.execute(
- self.tables.seq_pk.c.id.default
- )
- eq_(
- r, 1
- )
+ r = config.db.execute(self.tables.seq_pk.c.id.default)
+ eq_(r, 1)
@requirements.sequences_optional
def test_optional_seq(self):
r = config.db.execute(
- self.tables.seq_opt_pk.insert(),
- data="some data"
- )
- eq_(
- r.inserted_primary_key,
- [1]
+ self.tables.seq_opt_pk.insert(), data="some data"
)
+ eq_(r.inserted_primary_key, [1])
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
- eq_(
- row,
- (1, "some data")
- )
+ eq_(row, (1, "some data"))
class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
- __requires__ = ('sequences',)
+ __requires__ = ("sequences",)
__backend__ = True
def test_literal_binds_inline_compile(self):
table = Table(
- 'x', MetaData(),
- Column('y', Integer, Sequence('y_seq')),
- Column('q', Integer))
+ "x",
+ MetaData(),
+ Column("y", Integer, Sequence("y_seq")),
+ Column("q", Integer),
+ )
stmt = table.insert().values(q=5)
seq_nextval = testing.db.dialect.statement_compiler(
- statement=None, dialect=testing.db.dialect).visit_sequence(
- Sequence("y_seq"))
+ statement=None, dialect=testing.db.dialect
+ ).visit_sequence(Sequence("y_seq"))
self.assert_compile(
stmt,
- "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval, ),
+ "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,),
literal_binds=True,
- dialect=testing.db.dialect)
+ dialect=testing.db.dialect,
+ )
class HasSequenceTest(fixtures.TestBase):
- __requires__ = 'sequences',
+ __requires__ = ("sequences",)
__backend__ = True
def test_has_sequence(self):
- s1 = Sequence('user_id_seq')
+ s1 = Sequence("user_id_seq")
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(testing.db,
- 'user_id_seq'), True)
+ eq_(
+ testing.db.dialect.has_sequence(testing.db, "user_id_seq"),
+ True,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
@testing.requires.schemas
def test_has_sequence_schema(self):
- s1 = Sequence('user_id_seq', schema=config.test_schema)
+ s1 = Sequence("user_id_seq", schema=config.test_schema)
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(
- testing.db, 'user_id_seq', schema=config.test_schema), True)
+ eq_(
+ testing.db.dialect.has_sequence(
+ testing.db, "user_id_seq", schema=config.test_schema
+ ),
+ True,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
def test_has_sequence_neg(self):
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),
- False)
+ eq_(testing.db.dialect.has_sequence(testing.db, "user_id_seq"), False)
@testing.requires.schemas
def test_has_sequence_schemas_neg(self):
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
- schema=config.test_schema),
- False)
+ eq_(
+ testing.db.dialect.has_sequence(
+ testing.db, "user_id_seq", schema=config.test_schema
+ ),
+ False,
+ )
@testing.requires.schemas
def test_has_sequence_default_not_in_remote(self):
- s1 = Sequence('user_id_seq')
+ s1 = Sequence("user_id_seq")
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
- schema=config.test_schema),
- False)
+ eq_(
+ testing.db.dialect.has_sequence(
+ testing.db, "user_id_seq", schema=config.test_schema
+ ),
+ False,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
@testing.requires.schemas
def test_has_sequence_remote_not_in_default(self):
- s1 = Sequence('user_id_seq', schema=config.test_schema)
+ s1 = Sequence("user_id_seq", schema=config.test_schema)
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),
- False)
+ eq_(
+ testing.db.dialect.has_sequence(testing.db, "user_id_seq"),
+ False,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
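A standalone sketch of the Sequence round trip these tests cover, assuming a sequence-capable backend such as PostgreSQL (the URL is illustrative):

    from sqlalchemy import (
        Column,
        Integer,
        MetaData,
        Sequence,
        String,
        Table,
        create_engine,
    )

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    metadata = MetaData()
    seq_pk = Table(
        "seq_pk",
        metadata,
        Column("id", Integer, Sequence("tab_id_seq"), primary_key=True),
        Column("data", String(50)),
    )
    metadata.create_all(engine)

    r = engine.execute(seq_pk.insert(), data="some data")
    print(r.inserted_primary_key)  # [1]
    print(engine.dialect.has_sequence(engine, "tab_id_seq"))  # True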
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 27c7bb115..6dfb80915 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -4,9 +4,24 @@ from .. import fixtures, config
from ..assertions import eq_
from ..config import requirements
from sqlalchemy import Integer, Unicode, UnicodeText, select, TIMESTAMP
-from sqlalchemy import Date, DateTime, Time, MetaData, String, \
- Text, Numeric, Float, literal, Boolean, cast, null, JSON, and_, \
- type_coerce, BigInteger
+from sqlalchemy import (
+ Date,
+ DateTime,
+ Time,
+ MetaData,
+ String,
+ Text,
+ Numeric,
+ Float,
+ literal,
+ Boolean,
+ cast,
+ null,
+ JSON,
+ and_,
+ type_coerce,
+ BigInteger,
+)
from ..schema import Table, Column
from ... import testing
import decimal
@@ -24,13 +39,17 @@ class _LiteralRoundTripFixture(object):
# into a typed column. we can then SELECT it back as its
# official type; ideally we'd be able to use CAST here
# but MySQL in particular can't CAST fully
- t = Table('t', self.metadata, Column('x', type_))
+ t = Table("t", self.metadata, Column("x", type_))
t.create()
for value in input_:
- ins = t.insert().values(x=literal(value)).compile(
- dialect=testing.db.dialect,
- compile_kwargs=dict(literal_binds=True)
+ ins = (
+ t.insert()
+ .values(x=literal(value))
+ .compile(
+ dialect=testing.db.dialect,
+ compile_kwargs=dict(literal_binds=True),
+ )
)
testing.db.execute(ins)
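The fixture above compiles each INSERT with literal_binds so the value travels inline in the statement text rather than as a bound parameter. A minimal standalone version, with the engine as an assumption:

    from sqlalchemy import (
        Column,
        MetaData,
        String,
        Table,
        create_engine,
        literal,
    )

    engine = create_engine("sqlite://")  # assumed engine
    metadata = MetaData()
    t = Table("t", metadata, Column("x", String(40)))
    metadata.create_all(engine)

    # the bound value is rendered directly into the SQL string
    ins = (
        t.insert()
        .values(x=literal("some text"))
        .compile(
            dialect=engine.dialect, compile_kwargs=dict(literal_binds=True)
        )
    )
    engine.execute(ins)
    print(engine.execute(t.select()).fetchall())  # [('some text',)]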
@@ -42,40 +61,33 @@ class _LiteralRoundTripFixture(object):
class _UnicodeFixture(_LiteralRoundTripFixture):
- __requires__ = 'unicode_data',
+ __requires__ = ("unicode_data",)
- data = u("Alors vous imaginez ma surprise, au lever du jour, "
- "quand une drôle de petite voix m’a réveillé. Elle "
- "disait: « S’il vous plaît… dessine-moi un mouton! »")
+ data = u(
+ "Alors vous imaginez ma surprise, au lever du jour, "
+ "quand une drôle de petite voix m’a réveillé. Elle "
+ "disait: « S’il vous plaît… dessine-moi un mouton! »"
+ )
@classmethod
def define_tables(cls, metadata):
- Table('unicode_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('unicode_data', cls.datatype),
- )
+ Table(
+ "unicode_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("unicode_data", cls.datatype),
+ )
def test_round_trip(self):
unicode_table = self.tables.unicode_table
- config.db.execute(
- unicode_table.insert(),
- {
- 'unicode_data': self.data,
- }
- )
+ config.db.execute(unicode_table.insert(), {"unicode_data": self.data})
- row = config.db.execute(
- select([
- unicode_table.c.unicode_data,
- ])
- ).first()
+ row = config.db.execute(select([unicode_table.c.unicode_data])).first()
- eq_(
- row,
- (self.data, )
- )
+ eq_(row, (self.data,))
assert isinstance(row[0], util.text_type)
def test_round_trip_executemany(self):
@@ -83,44 +95,29 @@ class _UnicodeFixture(_LiteralRoundTripFixture):
config.db.execute(
unicode_table.insert(),
- [
- {
- 'unicode_data': self.data,
- }
- for i in range(3)
- ]
+ [{"unicode_data": self.data} for i in range(3)],
)
rows = config.db.execute(
- select([
- unicode_table.c.unicode_data,
- ])
+ select([unicode_table.c.unicode_data])
).fetchall()
- eq_(
- rows,
- [(self.data, ) for i in range(3)]
- )
+ eq_(rows, [(self.data,) for i in range(3)])
for row in rows:
assert isinstance(row[0], util.text_type)
def _test_empty_strings(self):
unicode_table = self.tables.unicode_table
- config.db.execute(
- unicode_table.insert(),
- {"unicode_data": u('')}
- )
- row = config.db.execute(
- select([unicode_table.c.unicode_data])
- ).first()
- eq_(row, (u(''),))
+ config.db.execute(unicode_table.insert(), {"unicode_data": u("")})
+ row = config.db.execute(select([unicode_table.c.unicode_data])).first()
+ eq_(row, (u(""),))
def test_literal(self):
self._literal_round_trip(self.datatype, [self.data], [self.data])
class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
- __requires__ = 'unicode_data',
+ __requires__ = ("unicode_data",)
__backend__ = True
datatype = Unicode(255)
@@ -131,7 +128,7 @@ class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
- __requires__ = 'unicode_data', 'text_type'
+ __requires__ = "unicode_data", "text_type"
__backend__ = True
datatype = UnicodeText()
@@ -142,54 +139,47 @@ class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest):
- __requires__ = 'text_type',
+ __requires__ = ("text_type",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('text_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('text_data', Text),
- )
+ Table(
+ "text_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("text_data", Text),
+ )
def test_text_roundtrip(self):
text_table = self.tables.text_table
- config.db.execute(
- text_table.insert(),
- {"text_data": 'some text'}
- )
- row = config.db.execute(
- select([text_table.c.text_data])
- ).first()
- eq_(row, ('some text',))
+ config.db.execute(text_table.insert(), {"text_data": "some text"})
+ row = config.db.execute(select([text_table.c.text_data])).first()
+ eq_(row, ("some text",))
def test_text_empty_strings(self):
text_table = self.tables.text_table
- config.db.execute(
- text_table.insert(),
- {"text_data": ''}
- )
- row = config.db.execute(
- select([text_table.c.text_data])
- ).first()
- eq_(row, ('',))
+ config.db.execute(text_table.insert(), {"text_data": ""})
+ row = config.db.execute(select([text_table.c.text_data])).first()
+ eq_(row, ("",))
def test_literal(self):
self._literal_round_trip(Text, ["some text"], ["some text"])
def test_literal_quoting(self):
- data = '''some 'text' hey "hi there" that's text'''
+ data = """some 'text' hey "hi there" that's text"""
self._literal_round_trip(Text, [data], [data])
def test_literal_backslashes(self):
- data = r'backslash one \ backslash two \\ end'
+ data = r"backslash one \ backslash two \\ end"
self._literal_round_trip(Text, [data], [data])
def test_literal_percentsigns(self):
- data = r'percent % signs %% percent'
+ data = r"percent % signs %% percent"
self._literal_round_trip(Text, [data], [data])
@@ -199,9 +189,7 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
@requirements.unbounded_varchar
def test_nolength_string(self):
metadata = MetaData()
- foo = Table('foo', metadata,
- Column('one', String)
- )
+ foo = Table("foo", metadata, Column("one", String))
foo.create(config.db)
foo.drop(config.db)
@@ -210,11 +198,11 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
self._literal_round_trip(String(40), ["some text"], ["some text"])
def test_literal_quoting(self):
- data = '''some 'text' hey "hi there" that's text'''
+ data = """some 'text' hey "hi there" that's text"""
self._literal_round_trip(String(40), [data], [data])
def test_literal_backslashes(self):
- data = r'backslash one \ backslash two \\ end'
+ data = r"backslash one \ backslash two \\ end"
self._literal_round_trip(String(40), [data], [data])
@@ -223,44 +211,32 @@ class _DateFixture(_LiteralRoundTripFixture):
@classmethod
def define_tables(cls, metadata):
- Table('date_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('date_data', cls.datatype),
- )
+ Table(
+ "date_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("date_data", cls.datatype),
+ )
def test_round_trip(self):
date_table = self.tables.date_table
- config.db.execute(
- date_table.insert(),
- {'date_data': self.data}
- )
+ config.db.execute(date_table.insert(), {"date_data": self.data})
- row = config.db.execute(
- select([
- date_table.c.date_data,
- ])
- ).first()
+ row = config.db.execute(select([date_table.c.date_data])).first()
compare = self.compare or self.data
- eq_(row,
- (compare, ))
+ eq_(row, (compare,))
assert isinstance(row[0], type(compare))
def test_null(self):
date_table = self.tables.date_table
- config.db.execute(
- date_table.insert(),
- {'date_data': None}
- )
+ config.db.execute(date_table.insert(), {"date_data": None})
- row = config.db.execute(
- select([
- date_table.c.date_data,
- ])
- ).first()
+ row = config.db.execute(select([date_table.c.date_data])).first()
eq_(row, (None,))
@testing.requires.datetime_literals
@@ -270,48 +246,49 @@ class _DateFixture(_LiteralRoundTripFixture):
class DateTimeTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'datetime',
+ __requires__ = ("datetime",)
__backend__ = True
datatype = DateTime
data = datetime.datetime(2012, 10, 15, 12, 57, 18)
class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'datetime_microseconds',
+ __requires__ = ("datetime_microseconds",)
__backend__ = True
datatype = DateTime
data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+
class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'timestamp_microseconds',
+ __requires__ = ("timestamp_microseconds",)
__backend__ = True
datatype = TIMESTAMP
data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
class TimeTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'time',
+ __requires__ = ("time",)
__backend__ = True
datatype = Time
data = datetime.time(12, 57, 18)
class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'time_microseconds',
+ __requires__ = ("time_microseconds",)
__backend__ = True
datatype = Time
data = datetime.time(12, 57, 18, 396)
class DateTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'date',
+ __requires__ = ("date",)
__backend__ = True
datatype = Date
data = datetime.date(2012, 10, 15)
class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'date', 'date_coerces_from_datetime'
+ __requires__ = "date", "date_coerces_from_datetime"
__backend__ = True
datatype = Date
data = datetime.datetime(2012, 10, 15, 12, 57, 18)
@@ -319,14 +296,14 @@ class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'datetime_historic',
+ __requires__ = ("datetime_historic",)
__backend__ = True
datatype = DateTime
data = datetime.datetime(1850, 11, 10, 11, 52, 35)
class DateHistoricTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'date_historic',
+ __requires__ = ("date_historic",)
__backend__ = True
datatype = Date
data = datetime.date(1727, 4, 1)
@@ -345,26 +322,21 @@ class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase):
def _round_trip(self, datatype, data):
metadata = self.metadata
int_table = Table(
- 'integer_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('integer_data', datatype),
+ "integer_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("integer_data", datatype),
)
metadata.create_all(config.db)
- config.db.execute(
- int_table.insert(),
- {'integer_data': data}
- )
+ config.db.execute(int_table.insert(), {"integer_data": data})
- row = config.db.execute(
- select([
- int_table.c.integer_data,
- ])
- ).first()
+ row = config.db.execute(select([int_table.c.integer_data])).first()
- eq_(row, (data, ))
+ eq_(row, (data,))
if util.py3k:
assert isinstance(row[0], int)
@@ -377,12 +349,11 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
@testing.emits_warning(r".*does \*not\* support Decimal objects natively")
@testing.provide_metadata
- def _do_test(self, type_, input_, output,
- filter_=None, check_scale=False):
+ def _do_test(self, type_, input_, output, filter_=None, check_scale=False):
metadata = self.metadata
- t = Table('t', metadata, Column('x', type_))
+ t = Table("t", metadata, Column("x", type_))
t.create()
- t.insert().execute([{'x': x} for x in input_])
+ t.insert().execute([{"x": x} for x in input_])
result = {row[0] for row in t.select().execute()}
output = set(output)
@@ -391,10 +362,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
output = set(filter_(x) for x in output)
eq_(result, output)
if check_scale:
- eq_(
- [str(x) for x in result],
- [str(x) for x in output],
- )
+ eq_([str(x) for x in result], [str(x) for x in output])
@testing.emits_warning(r".*does \*not\* support Decimal objects natively")
def test_render_literal_numeric(self):
@@ -416,8 +384,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
self._literal_round_trip(
Float(4),
[15.7563, decimal.Decimal("15.7563")],
- [15.7563, ],
- filter_=lambda n: n is not None and round(n, 5) or None
+ [15.7563],
+ filter_=lambda n: n is not None and round(n, 5) or None,
)
@testing.requires.precision_generic_float_type
@@ -425,8 +393,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
self._do_test(
Float(None, decimal_return_scale=7, asdecimal=True),
[15.7563827, decimal.Decimal("15.7563827")],
- [decimal.Decimal("15.7563827"), ],
- check_scale=True
+ [decimal.Decimal("15.7563827")],
+ check_scale=True,
)
def test_numeric_as_decimal(self):
@@ -445,18 +413,12 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
@testing.requires.fetch_null_from_numeric
def test_numeric_null_as_decimal(self):
- self._do_test(
- Numeric(precision=8, scale=4),
- [None],
- [None],
- )
+ self._do_test(Numeric(precision=8, scale=4), [None], [None])
@testing.requires.fetch_null_from_numeric
def test_numeric_null_as_float(self):
self._do_test(
- Numeric(precision=8, scale=4, asdecimal=False),
- [None],
- [None],
+ Numeric(precision=8, scale=4, asdecimal=False), [None], [None]
)
@testing.requires.floats_to_four_decimals
@@ -472,15 +434,13 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
Float(precision=8),
[15.7563, decimal.Decimal("15.7563")],
[15.7563],
- filter_=lambda n: n is not None and round(n, 5) or None
+ filter_=lambda n: n is not None and round(n, 5) or None,
)
def test_float_coerce_round_trip(self):
expr = 15.7563
- val = testing.db.scalar(
- select([literal(expr)])
- )
+ val = testing.db.scalar(select([literal(expr)]))
eq_(val, expr)
# this does not work in MySQL, see #4036, however we choose not
@@ -491,34 +451,28 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
def test_decimal_coerce_round_trip(self):
expr = decimal.Decimal("15.7563")
- val = testing.db.scalar(
- select([literal(expr)])
- )
+ val = testing.db.scalar(select([literal(expr)]))
eq_(val, expr)
@testing.emits_warning(r".*does \*not\* support Decimal objects natively")
def test_decimal_coerce_round_trip_w_cast(self):
expr = decimal.Decimal("15.7563")
- val = testing.db.scalar(
- select([cast(expr, Numeric(10, 4))])
- )
+ val = testing.db.scalar(select([cast(expr, Numeric(10, 4))]))
eq_(val, expr)
@testing.requires.precision_numerics_general
def test_precision_decimal(self):
- numbers = set([
- decimal.Decimal("54.234246451650"),
- decimal.Decimal("0.004354"),
- decimal.Decimal("900.0"),
- ])
-
- self._do_test(
- Numeric(precision=18, scale=12),
- numbers,
- numbers,
+ numbers = set(
+ [
+ decimal.Decimal("54.234246451650"),
+ decimal.Decimal("0.004354"),
+ decimal.Decimal("900.0"),
+ ]
)
+ self._do_test(Numeric(precision=18, scale=12), numbers, numbers)
+
@testing.requires.precision_numerics_enotation_large
def test_enotation_decimal(self):
"""test exceedingly small decimals.
@@ -528,25 +482,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
"""
- numbers = set([
- decimal.Decimal('1E-2'),
- decimal.Decimal('1E-3'),
- decimal.Decimal('1E-4'),
- decimal.Decimal('1E-5'),
- decimal.Decimal('1E-6'),
- decimal.Decimal('1E-7'),
- decimal.Decimal('1E-8'),
- decimal.Decimal("0.01000005940696"),
- decimal.Decimal("0.00000005940696"),
- decimal.Decimal("0.00000000000696"),
- decimal.Decimal("0.70000000000696"),
- decimal.Decimal("696E-12"),
- ])
- self._do_test(
- Numeric(precision=18, scale=14),
- numbers,
- numbers
+ numbers = set(
+ [
+ decimal.Decimal("1E-2"),
+ decimal.Decimal("1E-3"),
+ decimal.Decimal("1E-4"),
+ decimal.Decimal("1E-5"),
+ decimal.Decimal("1E-6"),
+ decimal.Decimal("1E-7"),
+ decimal.Decimal("1E-8"),
+ decimal.Decimal("0.01000005940696"),
+ decimal.Decimal("0.00000005940696"),
+ decimal.Decimal("0.00000000000696"),
+ decimal.Decimal("0.70000000000696"),
+ decimal.Decimal("696E-12"),
+ ]
)
+ self._do_test(Numeric(precision=18, scale=14), numbers, numbers)
@testing.requires.precision_numerics_enotation_large
def test_enotation_decimal_large(self):
@@ -554,41 +506,32 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
"""
- numbers = set([
- decimal.Decimal('4E+8'),
- decimal.Decimal("5748E+15"),
- decimal.Decimal('1.521E+15'),
- decimal.Decimal('00000000000000.1E+12'),
- ])
- self._do_test(
- Numeric(precision=25, scale=2),
- numbers,
- numbers
+ numbers = set(
+ [
+ decimal.Decimal("4E+8"),
+ decimal.Decimal("5748E+15"),
+ decimal.Decimal("1.521E+15"),
+ decimal.Decimal("00000000000000.1E+12"),
+ ]
)
+ self._do_test(Numeric(precision=25, scale=2), numbers, numbers)
@testing.requires.precision_numerics_many_significant_digits
def test_many_significant_digits(self):
- numbers = set([
- decimal.Decimal("31943874831932418390.01"),
- decimal.Decimal("319438950232418390.273596"),
- decimal.Decimal("87673.594069654243"),
- ])
- self._do_test(
- Numeric(precision=38, scale=12),
- numbers,
- numbers
+ numbers = set(
+ [
+ decimal.Decimal("31943874831932418390.01"),
+ decimal.Decimal("319438950232418390.273596"),
+ decimal.Decimal("87673.594069654243"),
+ ]
)
+ self._do_test(Numeric(precision=38, scale=12), numbers, numbers)
@testing.requires.precision_numerics_retains_significant_digits
def test_numeric_no_decimal(self):
- numbers = set([
- decimal.Decimal("1.000")
- ])
+ numbers = set([decimal.Decimal("1.000")])
self._do_test(
- Numeric(precision=5, scale=3),
- numbers,
- numbers,
- check_scale=True
+ Numeric(precision=5, scale=3), numbers, numbers, check_scale=True
)
@@ -597,42 +540,32 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table('boolean_table', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('value', Boolean),
- Column('unconstrained_value', Boolean(create_constraint=False)),
- )
+ Table(
+ "boolean_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("value", Boolean),
+ Column("unconstrained_value", Boolean(create_constraint=False)),
+ )
def test_render_literal_bool(self):
- self._literal_round_trip(
- Boolean(),
- [True, False],
- [True, False]
- )
+ self._literal_round_trip(Boolean(), [True, False], [True, False])
def test_round_trip(self):
boolean_table = self.tables.boolean_table
config.db.execute(
boolean_table.insert(),
- {
- 'id': 1,
- 'value': True,
- 'unconstrained_value': False
- }
+ {"id": 1, "value": True, "unconstrained_value": False},
)
row = config.db.execute(
- select([
- boolean_table.c.value,
- boolean_table.c.unconstrained_value
- ])
+ select(
+ [boolean_table.c.value, boolean_table.c.unconstrained_value]
+ )
).first()
- eq_(
- row,
- (True, False)
- )
+ eq_(row, (True, False))
assert isinstance(row[0], bool)
def test_null(self):
@@ -640,24 +573,16 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
config.db.execute(
boolean_table.insert(),
- {
- 'id': 1,
- 'value': None,
- 'unconstrained_value': None
- }
+ {"id": 1, "value": None, "unconstrained_value": None},
)
row = config.db.execute(
- select([
- boolean_table.c.value,
- boolean_table.c.unconstrained_value
- ])
+ select(
+ [boolean_table.c.value, boolean_table.c.unconstrained_value]
+ )
).first()
- eq_(
- row,
- (None, None)
- )
+ eq_(row, (None, None))
def test_whereclause(self):
# testing "WHERE <column>" renders a compatible expression
@@ -667,92 +592,82 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
conn.execute(
boolean_table.insert(),
[
- {'id': 1, 'value': True, 'unconstrained_value': True},
- {'id': 2, 'value': False, 'unconstrained_value': False}
- ]
+ {"id": 1, "value": True, "unconstrained_value": True},
+ {"id": 2, "value": False, "unconstrained_value": False},
+ ],
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(boolean_table.c.value)
),
- 1
+ 1,
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(
- boolean_table.c.unconstrained_value)
+ boolean_table.c.unconstrained_value
+ )
),
- 1
+ 1,
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(~boolean_table.c.value)
),
- 2
+ 2,
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(
- ~boolean_table.c.unconstrained_value)
+ ~boolean_table.c.unconstrained_value
+ )
),
- 2
+ 2,
)
-
-
class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
- __requires__ = 'json_type',
+ __requires__ = ("json_type",)
__backend__ = True
datatype = JSON
- data1 = {
- "key1": "value1",
- "key2": "value2"
- }
+ data1 = {"key1": "value1", "key2": "value2"}
data2 = {
"Key 'One'": "value1",
"key two": "value2",
- "key three": "value ' three '"
+ "key three": "value ' three '",
}
data3 = {
"key1": [1, 2, 3],
"key2": ["one", "two", "three"],
- "key3": [{"four": "five"}, {"six": "seven"}]
+ "key3": [{"four": "five"}, {"six": "seven"}],
}
data4 = ["one", "two", "three"]
data5 = {
"nested": {
- "elem1": [
- {"a": "b", "c": "d"},
- {"e": "f", "g": "h"}
- ],
- "elem2": {
- "elem3": {"elem4": "elem5"}
- }
+ "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}],
+ "elem2": {"elem3": {"elem4": "elem5"}},
}
}
- data6 = {
- "a": 5,
- "b": "some value",
- "c": {"foo": "bar"}
- }
+ data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}}
@classmethod
def define_tables(cls, metadata):
- Table('data_table', metadata,
- Column('id', Integer, primary_key=True),
- Column('name', String(30), nullable=False),
- Column('data', cls.datatype),
- Column('nulldata', cls.datatype(none_as_null=True))
- )
+ Table(
+ "data_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(30), nullable=False),
+ Column("data", cls.datatype),
+ Column("nulldata", cls.datatype(none_as_null=True)),
+ )
def test_round_trip_data1(self):
self._test_round_trip(self.data1)
@@ -761,99 +676,82 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
data_table = self.tables.data_table
config.db.execute(
- data_table.insert(),
- {'name': 'row1', 'data': data_element}
+ data_table.insert(), {"name": "row1", "data": data_element}
)
- row = config.db.execute(
- select([
- data_table.c.data,
- ])
- ).first()
+ row = config.db.execute(select([data_table.c.data])).first()
- eq_(row, (data_element, ))
+ eq_(row, (data_element,))
def test_round_trip_none_as_sql_null(self):
- col = self.tables.data_table.c['nulldata']
+ col = self.tables.data_table.c["nulldata"]
with config.db.connect() as conn:
conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": None}
+ self.tables.data_table.insert(), {"name": "r1", "data": None}
)
eq_(
conn.scalar(
- select([self.tables.data_table.c.name]).
- where(col.is_(null()))
+ select([self.tables.data_table.c.name]).where(
+ col.is_(null())
+ )
),
- "r1"
+ "r1",
)
- eq_(
- conn.scalar(
- select([col])
- ),
- None
- )
+ eq_(conn.scalar(select([col])), None)
def test_round_trip_json_null_as_json_null(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
with config.db.connect() as conn:
conn.execute(
self.tables.data_table.insert(),
- {"name": "r1", "data": JSON.NULL}
+ {"name": "r1", "data": JSON.NULL},
)
eq_(
conn.scalar(
- select([self.tables.data_table.c.name]).
- where(cast(col, String) == 'null')
+ select([self.tables.data_table.c.name]).where(
+ cast(col, String) == "null"
+ )
),
- "r1"
+ "r1",
)
- eq_(
- conn.scalar(
- select([col])
- ),
- None
- )
+ eq_(conn.scalar(select([col])), None)
def test_round_trip_none_as_json_null(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
with config.db.connect() as conn:
conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": None}
+ self.tables.data_table.insert(), {"name": "r1", "data": None}
)
eq_(
conn.scalar(
- select([self.tables.data_table.c.name]).
- where(cast(col, String) == 'null')
+ select([self.tables.data_table.c.name]).where(
+ cast(col, String) == "null"
+ )
),
- "r1"
+ "r1",
)
- eq_(
- conn.scalar(
- select([col])
- ),
- None
- )
+ eq_(conn.scalar(select([col])), None)
def _criteria_fixture(self):
config.db.execute(
self.tables.data_table.insert(),
- [{"name": "r1", "data": self.data1},
- {"name": "r2", "data": self.data2},
- {"name": "r3", "data": self.data3},
- {"name": "r4", "data": self.data4},
- {"name": "r5", "data": self.data5},
- {"name": "r6", "data": self.data6}]
+ [
+ {"name": "r1", "data": self.data1},
+ {"name": "r2", "data": self.data2},
+ {"name": "r3", "data": self.data3},
+ {"name": "r4", "data": self.data4},
+ {"name": "r5", "data": self.data5},
+ {"name": "r6", "data": self.data6},
+ ],
)
def _test_index_criteria(self, crit, expected, test_literal=True):
@@ -861,20 +759,20 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
with config.db.connect() as conn:
stmt = select([self.tables.data_table.c.name]).where(crit)
- eq_(
- conn.scalar(stmt),
- expected
- )
+ eq_(conn.scalar(stmt), expected)
if test_literal:
- literal_sql = str(stmt.compile(
- config.db, compile_kwargs={"literal_binds": True}))
+ literal_sql = str(
+ stmt.compile(
+ config.db, compile_kwargs={"literal_binds": True}
+ )
+ )
eq_(conn.scalar(literal_sql), expected)
def test_crit_spaces_in_key(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
# limit the rows here to avoid PG error
# "cannot extract field from a non-object", which is
@@ -882,76 +780,74 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
self._test_index_criteria(
and_(
name.in_(["r1", "r2", "r3"]),
- cast(col["key two"], String) == '"value2"'
+ cast(col["key two"], String) == '"value2"',
),
- "r2"
+ "r2",
)
@config.requirements.json_array_indexes
def test_crit_simple_int(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
# limit the rows here to avoid PG error
# "cannot extract array element from a non-array", which is
# fixed in 9.4 but may exist in 9.3
self._test_index_criteria(
- and_(name == 'r4', cast(col[1], String) == '"two"'),
- "r4"
+ and_(name == "r4", cast(col[1], String) == '"two"'), "r4"
)
def test_crit_mixed_path(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- cast(col[("key3", 1, "six")], String) == '"seven"',
- "r3"
+ cast(col[("key3", 1, "six")], String) == '"seven"', "r3"
)
def test_crit_string_path(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
cast(col[("nested", "elem2", "elem3", "elem4")], String)
== '"elem5"',
- "r5"
+ "r5",
)
def test_crit_against_string_basic(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6', cast(col["b"], String) == '"some value"'),
- "r6"
+ and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6"
)
def test_crit_against_string_coerce_type(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6',
- cast(col["b"], String) == type_coerce("some value", JSON)),
+ and_(
+ name == "r6",
+ cast(col["b"], String) == type_coerce("some value", JSON),
+ ),
"r6",
- test_literal=False
+ test_literal=False,
)
def test_crit_against_int_basic(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6', cast(col["a"], String) == '5'),
- "r6"
+ and_(name == "r6", cast(col["a"], String) == "5"), "r6"
)
def test_crit_against_int_coerce_type(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6', cast(col["a"], String) == type_coerce(5, JSON)),
+ and_(name == "r6", cast(col["a"], String) == type_coerce(5, JSON)),
"r6",
- test_literal=False
+ test_literal=False,
)
def test_unicode_round_trip(self):
@@ -961,17 +857,17 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
{
"name": "r1",
"data": {
- util.u('réveillé'): util.u('réveillé'),
- "data": {"k1": util.u('drôle')}
- }
- }
+ util.u("réveillé"): util.u("réveillé"),
+ "data": {"k1": util.u("drôle")},
+ },
+ },
)
eq_(
conn.scalar(select([self.tables.data_table.c.data])),
{
- util.u('réveillé'): util.u('réveillé'),
- "data": {"k1": util.u('drôle')}
+ util.u("réveillé"): util.u("réveillé"),
+ "data": {"k1": util.u("drôle")},
},
)
@@ -986,7 +882,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
s = Session(testing.db)
- d1 = Data(name='d1', data=None, nulldata=None)
+ d1 = Data(name="d1", data=None, nulldata=None)
s.add(d1)
s.commit()
@@ -995,24 +891,46 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
)
eq_(
s.query(
- cast(self.tables.data_table.c.data, String(convert_unicode="force")),
- cast(self.tables.data_table.c.nulldata, String)
- ).filter(self.tables.data_table.c.name == 'd1').first(),
- ("null", None)
+ cast(
+ self.tables.data_table.c.data,
+ String(convert_unicode="force"),
+ ),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
)
eq_(
s.query(
- cast(self.tables.data_table.c.data, String(convert_unicode="force")),
- cast(self.tables.data_table.c.nulldata, String)
- ).filter(self.tables.data_table.c.name == 'd2').first(),
- ("null", None)
- )
-
-
-__all__ = ('UnicodeVarcharTest', 'UnicodeTextTest', 'JSONTest',
- 'DateTest', 'DateTimeTest', 'TextTest',
- 'NumericTest', 'IntegerTest',
- 'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest',
- 'TimeMicrosecondsTest', 'TimestampMicrosecondsTest', 'TimeTest',
- 'DateTimeMicrosecondsTest',
- 'DateHistoricTest', 'StringTest', 'BooleanTest')
+ cast(
+ self.tables.data_table.c.data,
+ String(convert_unicode="force"),
+ ),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
+ )
+
+
+__all__ = (
+ "UnicodeVarcharTest",
+ "UnicodeTextTest",
+ "JSONTest",
+ "DateTest",
+ "DateTimeTest",
+ "TextTest",
+ "NumericTest",
+ "IntegerTest",
+ "DateTimeHistoricTest",
+ "DateTimeCoercedToDateTimeTest",
+ "TimeMicrosecondsTest",
+ "TimestampMicrosecondsTest",
+ "TimeTest",
+ "DateTimeMicrosecondsTest",
+ "DateHistoricTest",
+ "StringTest",
+ "BooleanTest",
+)
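One more reader note before the next file: the JSON NULL distinction at the heart of the tests above. With none_as_null=True a Python None is stored as SQL NULL; with the plain type it is stored as the JSON encoding 'null'. A sketch, assuming a backend whose json_type requirement passes:

    from sqlalchemy import (
        JSON,
        Column,
        Integer,
        MetaData,
        String,
        Table,
        cast,
        create_engine,
        null,
        select,
    )

    engine = create_engine("sqlite://")  # assumption: json-capable backend
    metadata = MetaData()
    data_table = Table(
        "data_table",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(30), nullable=False),
        Column("data", JSON),  # Python None stored as JSON 'null'
        Column("nulldata", JSON(none_as_null=True)),  # None stored as SQL NULL
    )
    metadata.create_all(engine)

    with engine.connect() as conn:
        conn.execute(
            data_table.insert(),
            {"id": 1, "name": "r1", "data": None, "nulldata": None},
        )
        is_sql_null = select([data_table.c.name]).where(
            data_table.c.nulldata.is_(null())
        )
        is_json_null = select([data_table.c.name]).where(
            cast(data_table.c.data, String) == "null"
        )
        print(conn.scalar(is_sql_null))  # 'r1'
        print(conn.scalar(is_json_null))  # 'r1'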
diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py
index e4c61e74a..b232c3a78 100644
--- a/lib/sqlalchemy/testing/suite/test_update_delete.py
+++ b/lib/sqlalchemy/testing/suite/test_update_delete.py
@@ -6,15 +6,17 @@ from ..schema import Table, Column
class SimpleUpdateDeleteTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('plain_pk', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -24,40 +26,29 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest):
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
- ]
+ ],
)
def test_update(self):
t = self.tables.plain_pk
- r = config.db.execute(
- t.update().where(t.c.id == 2),
- data="d2_new"
- )
+ r = config.db.execute(t.update().where(t.c.id == 2), data="d2_new")
assert not r.is_insert
assert not r.returns_rows
eq_(
config.db.execute(t.select().order_by(t.c.id)).fetchall(),
- [
- (1, "d1"),
- (2, "d2_new"),
- (3, "d3")
- ]
+ [(1, "d1"), (2, "d2_new"), (3, "d3")],
)
def test_delete(self):
t = self.tables.plain_pk
- r = config.db.execute(
- t.delete().where(t.c.id == 2)
- )
+ r = config.db.execute(t.delete().where(t.c.id == 2))
assert not r.is_insert
assert not r.returns_rows
eq_(
config.db.execute(t.select().order_by(t.c.id)).fetchall(),
- [
- (1, "d1"),
- (3, "d3")
- ]
+ [(1, "d1"), (3, "d3")],
)
-__all__ = ('SimpleUpdateDeleteTest', )
+
+__all__ = ("SimpleUpdateDeleteTest",)
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
index 409d3bda5..5b015d214 100644
--- a/lib/sqlalchemy/testing/util.py
+++ b/lib/sqlalchemy/testing/util.py
@@ -14,6 +14,7 @@ import sys
import types
if jython:
+
def jython_gc_collect(*args):
"""aggressive gc.collect for tests."""
gc.collect()
@@ -25,9 +26,11 @@ if jython:
# "lazy" gc, for VM's that don't GC on refcount == 0
gc_collect = lazy_gc = jython_gc_collect
elif pypy:
+
def pypy_gc_collect(*args):
gc.collect()
gc.collect()
+
gc_collect = lazy_gc = pypy_gc_collect
else:
# assume CPython - straight gc.collect, lazy_gc() is a pass
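
The branching above exists because CPython frees an object the moment its reference count reaches zero, while Jython and PyPy use tracing collectors, so tests that rely on prompt finalization must force a collection; that is what lazy_gc is for. A sketch of the pattern (the probe class is a hypothetical stand-in):

    # Sketch: a test that needs lazy_gc() to pass on PyPy/Jython.
    import weakref

    class _Probe(object):  # hypothetical stand-in for a mapped object
        pass

    probe = _Probe()
    ref = weakref.ref(probe)
    del probe
    lazy_gc()  # a pass on CPython; a real gc.collect() elsewhere
    assert ref() is None
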
@@ -42,11 +45,13 @@ def picklers():
if py2k:
try:
import cPickle
+
picklers.add(cPickle)
except ImportError:
pass
import pickle
+
picklers.add(pickle)
# yes, this thing needs this much testing
@@ -60,9 +65,9 @@ def round_decimal(value, prec):
return round(value, prec)
# can also use shift() here but that is 2.6 only
- return (value * decimal.Decimal("1" + "0" * prec)
- ).to_integral(decimal.ROUND_FLOOR) / \
- pow(10, prec)
+ return (value * decimal.Decimal("1" + "0" * prec)).to_integral(
+ decimal.ROUND_FLOOR
+ ) / pow(10, prec)
class RandomSet(set):
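
round_decimal scales the value up by 10**prec, truncates toward negative infinity, then scales back down. Worked through for round_decimal(Decimal("2.567"), 2):

    import decimal

    value = decimal.Decimal("2.567")
    scaled = value * decimal.Decimal("100")            # 256.7
    floored = scaled.to_integral(decimal.ROUND_FLOOR)  # 256
    assert floored / pow(10, 2) == decimal.Decimal("2.56")
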
@@ -137,8 +142,9 @@ def function_named(fn, name):
try:
fn.__name__ = name
except TypeError:
- fn = types.FunctionType(fn.__code__, fn.__globals__, name,
- fn.__defaults__, fn.__closure__)
+ fn = types.FunctionType(
+ fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__
+ )
return fn
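
function_named first tries the cheap route of assigning fn.__name__ directly, and only rebuilds the function object through types.FunctionType when that assignment raises TypeError (as it can for unicode names on Python 2). A quick sketch of its use for generated tests (helper names hypothetical):

    # Sketch: give a dynamically generated test a readable name.
    def _make_test(value):
        def template(self):
            assert value > 0
        return function_named(template, "test_value_%d" % value)

    test_value_3 = _make_test(3)
    assert test_value_3.__name__ == "test_value_3"
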
@@ -190,7 +196,7 @@ def provide_metadata(fn, *args, **kw):
metadata = schema.MetaData(config.db)
self = args[0]
- prev_meta = getattr(self, 'metadata', None)
+ prev_meta = getattr(self, "metadata", None)
self.metadata = metadata
try:
return fn(*args, **kw)
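
provide_metadata stashes whatever metadata attribute the test object already has, installs a fresh MetaData bound to config.db for the duration of the call, and restores the previous attribute afterwards. Usage sketch (test and table names illustrative):

    # Sketch: a test that receives a disposable self.metadata.
    @provide_metadata
    def test_reflection_case(self):
        t = Table("t", self.metadata, Column("q", Integer))
        t.create()  # bound MetaData, so no engine argument needed
        # ... tables created here are torn down by the decorator ...
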
@@ -213,8 +219,8 @@ def force_drop_names(*names):
try:
return fn(*args, **kw)
finally:
- drop_all_tables(
- config.db, inspect(config.db), include_names=names)
+ drop_all_tables(config.db, inspect(config.db), include_names=names)
+
return go
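
force_drop_names complements that: however the wrapped test exits, the listed tables are unconditionally dropped afterwards through drop_all_tables with include_names, so a failing test cannot leak schema into its neighbors. Sketch (table names illustrative):

    # Sketch: guarantee cleanup of ad-hoc tables a test creates.
    @force_drop_names("scratch_a", "scratch_b")
    def test_creates_scratch_tables(self):
        # ... create and populate scratch_a / scratch_b via config.db ...
        pass
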
@@ -234,8 +240,13 @@ class adict(dict):
def drop_all_tables(engine, inspector, schema=None, include_names=None):
- from sqlalchemy import Column, Table, Integer, MetaData, \
- ForeignKeyConstraint
+ from sqlalchemy import (
+ Column,
+ Table,
+ Integer,
+ MetaData,
+ ForeignKeyConstraint,
+ )
from sqlalchemy.schema import DropTable, DropConstraint
if include_names is not None:
@@ -243,30 +254,35 @@ def drop_all_tables(engine, inspector, schema=None, include_names=None):
with engine.connect() as conn:
for tname, fkcs in reversed(
- inspector.get_sorted_table_and_fkc_names(schema=schema)):
+ inspector.get_sorted_table_and_fkc_names(schema=schema)
+ ):
if tname:
if include_names is not None and tname not in include_names:
continue
- conn.execute(DropTable(
- Table(tname, MetaData(), schema=schema)
- ))
+ conn.execute(
+ DropTable(Table(tname, MetaData(), schema=schema))
+ )
elif fkcs:
if not engine.dialect.supports_alter:
continue
for tname, fkc in fkcs:
- if include_names is not None and \
- tname not in include_names:
+ if (
+ include_names is not None
+ and tname not in include_names
+ ):
continue
tb = Table(
- tname, MetaData(),
- Column('x', Integer),
- Column('y', Integer),
- schema=schema
+ tname,
+ MetaData(),
+ Column("x", Integer),
+ Column("y", Integer),
+ schema=schema,
+ )
+ conn.execute(
+ DropConstraint(
+ ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc)
+ )
)
- conn.execute(DropConstraint(
- ForeignKeyConstraint(
- [tb.c.x], [tb.c.y], name=fkc)
- ))
def teardown_events(event_cls):
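
drop_all_tables works purely from reflection: it walks inspector.get_sorted_table_and_fkc_names() in reverse, drops foreign key constraints first (building throwaway Table/Column placeholders just so DropConstraint has something to attach to), then drops the tables themselves, which lets it unwind schemas containing FK cycles. Typical invocation (URL illustrative):

    # Sketch: clear out whatever the target database currently contains.
    from sqlalchemy import create_engine, inspect

    engine = create_engine("postgresql://scott:tiger@localhost/test")
    drop_all_tables(engine, inspect(engine))
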
@@ -276,5 +292,5 @@ def teardown_events(event_cls):
return fn(*arg, **kw)
finally:
event_cls._clear()
- return decorate
+ return decorate
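
teardown_events follows the same guarantee-cleanup shape: the returned decorator always calls event_cls._clear() in its finally block, so listeners registered during a test cannot bleed into later ones. Hedged sketch, assuming SomeEvents is one of the event namespace classes and the other names are test-local:

    # Sketch: event registrations cleared even if the test fails.
    @teardown_events(SomeEvents)  # hypothetical event class
    def test_listener_fires(self):
        event.listen(SomeTarget, "some_event", my_listener)
        # ... provoke the event and assert on my_listener ...
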
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
index 46e7c54db..e0101b14d 100644
--- a/lib/sqlalchemy/testing/warnings.py
+++ b/lib/sqlalchemy/testing/warnings.py
@@ -15,17 +15,20 @@ from . import assertions
def setup_filters():
"""Set global warning behavior for the test suite."""
- warnings.filterwarnings('ignore',
- category=sa_exc.SAPendingDeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SAWarning)
+ warnings.filterwarnings(
+ "ignore", category=sa_exc.SAPendingDeprecationWarning
+ )
+ warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings("error", category=sa_exc.SAWarning)
# some selected deprecations...
- warnings.filterwarnings('error', category=DeprecationWarning)
+ warnings.filterwarnings("error", category=DeprecationWarning)
warnings.filterwarnings(
- "ignore", category=DeprecationWarning, message=".*StopIteration")
+ "ignore", category=DeprecationWarning, message=".*StopIteration"
+ )
warnings.filterwarnings(
- "ignore", category=DeprecationWarning, message=".*inspect.getargspec")
+ "ignore", category=DeprecationWarning, message=".*inspect.getargspec"
+ )
def assert_warnings(fn, warning_msgs, regex=False):
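
Taken together, setup_filters escalates SAWarning, SADeprecationWarning, and most DeprecationWarnings to errors for the whole suite, while SAPendingDeprecationWarning and the two ignored DeprecationWarning message patterns stay silenced. The practical effect, sketched:

    # After setup_filters(), a stray SAWarning raises rather than prints.
    import warnings
    from sqlalchemy import exc as sa_exc

    setup_filters()
    try:
        warnings.warn("unexpected condition", sa_exc.SAWarning)
    except sa_exc.SAWarning:
        pass  # escalated to an exception, as the filters intend
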
@@ -36,6 +39,6 @@ def assert_warnings(fn, warning_msgs, regex=False):
"""
with assertions._expect_warnings(
- sa_exc.SAWarning, warning_msgs, regex=regex):
+ sa_exc.SAWarning, warning_msgs, regex=regex
+ ):
return fn()
-
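
assert_warnings itself is a thin wrapper over assertions._expect_warnings: it invokes the callable, requires each message in warning_msgs to be emitted as an SAWarning (treated as a regular expression when regex=True), and passes the callable's return value through. Usage sketch with a hypothetical function under test:

    # Sketch: assert that a call emits a specific SAWarning.
    import warnings
    from sqlalchemy import exc as sa_exc

    def configure_thing():  # hypothetical function under test
        warnings.warn("unrecognized option 'foo'", sa_exc.SAWarning)
        return 42

    assert assert_warnings(
        configure_thing, ["unrecognized option 'foo'"]
    ) == 42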