summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-01-06 01:14:26 -0500
committermike bayer <mike_mp@zzzcomputing.com>2019-01-06 17:34:50 +0000
commit1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch)
tree28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/testing
parent404e69426b05a82d905cbb3ad33adafccddb00dd (diff)
downloadsqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits applied at all. The black run will format code consistently, however in some cases that are prevalent in SQLAlchemy code it produces too-long lines. The too-long lines will be resolved in the following commit that will resolve all remaining flake8 issues including shadowed builtins, long lines, import order, unused imports, duplicate imports, and docstring issues. Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/__init__.py61
-rw-r--r--lib/sqlalchemy/testing/assertions.py182
-rw-r--r--lib/sqlalchemy/testing/assertsql.py148
-rw-r--r--lib/sqlalchemy/testing/config.py6
-rw-r--r--lib/sqlalchemy/testing/engines.py60
-rw-r--r--lib/sqlalchemy/testing/entities.py23
-rw-r--r--lib/sqlalchemy/testing/exclusions.py123
-rw-r--r--lib/sqlalchemy/testing/fixtures.py85
-rw-r--r--lib/sqlalchemy/testing/mock.py3
-rw-r--r--lib/sqlalchemy/testing/pickleable.py36
-rw-r--r--lib/sqlalchemy/testing/plugin/bootstrap.py7
-rw-r--r--lib/sqlalchemy/testing/plugin/noseplugin.py23
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py359
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py104
-rw-r--r--lib/sqlalchemy/testing/profiling.py71
-rw-r--r--lib/sqlalchemy/testing/provision.py86
-rw-r--r--lib/sqlalchemy/testing/replay_fixture.py81
-rw-r--r--lib/sqlalchemy/testing/requirements.py78
-rw-r--r--lib/sqlalchemy/testing/runner.py2
-rw-r--r--lib/sqlalchemy/testing/schema.py68
-rw-r--r--lib/sqlalchemy/testing/suite/__init__.py1
-rw-r--r--lib/sqlalchemy/testing/suite/test_cte.py132
-rw-r--r--lib/sqlalchemy/testing/suite/test_ddl.py50
-rw-r--r--lib/sqlalchemy/testing/suite/test_dialect.py54
-rw-r--r--lib/sqlalchemy/testing/suite/test_insert.py221
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py762
-rw-r--r--lib/sqlalchemy/testing/suite/test_results.py242
-rw-r--r--lib/sqlalchemy/testing/suite/test_select.py483
-rw-r--r--lib/sqlalchemy/testing/suite/test_sequence.py136
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py702
-rw-r--r--lib/sqlalchemy/testing/suite/test_update_delete.py37
-rw-r--r--lib/sqlalchemy/testing/util.py66
-rw-r--r--lib/sqlalchemy/testing/warnings.py21
33 files changed, 2399 insertions, 2114 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index 413a492b8..f46ca4528 100644
--- a/lib/sqlalchemy/testing/__init__.py
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -10,23 +10,62 @@ from .warnings import assert_warnings
from . import config
-from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
- fails_on, fails_on_everything_except, skip, only_on, exclude, \
- against as _against, _server_version, only_if, fails
+from .exclusions import (
+ db_spec,
+ _is_excluded,
+ fails_if,
+ skip_if,
+ future,
+ fails_on,
+ fails_on_everything_except,
+ skip,
+ only_on,
+ exclude,
+ against as _against,
+ _server_version,
+ only_if,
+ fails,
+)
def against(*queries):
return _against(config._current, *queries)
-from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
- eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \
- assert_raises_message, AssertsCompiledSQL, ComparesTables, \
- AssertsExecutionResults, expect_deprecated, expect_warnings, \
- in_, not_in_, eq_ignore_whitespace, eq_regex, is_true, is_false
-from .util import run_as_contextmanager, rowset, fail, \
- provide_metadata, adict, force_drop_names, \
- teardown_events
+from .assertions import (
+ emits_warning,
+ emits_warning_on,
+ uses_deprecated,
+ eq_,
+ ne_,
+ le_,
+ is_,
+ is_not_,
+ startswith_,
+ assert_raises,
+ assert_raises_message,
+ AssertsCompiledSQL,
+ ComparesTables,
+ AssertsExecutionResults,
+ expect_deprecated,
+ expect_warnings,
+ in_,
+ not_in_,
+ eq_ignore_whitespace,
+ eq_regex,
+ is_true,
+ is_false,
+)
+
+from .util import (
+ run_as_contextmanager,
+ rowset,
+ fail,
+ provide_metadata,
+ adict,
+ force_drop_names,
+ teardown_events,
+)
crashes = skip
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index e42376921..73ab4556a 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -86,6 +86,7 @@ def emits_warning_on(db, *messages):
were in fact seen.
"""
+
@decorator
def decorate(fn, *args, **kw):
with expect_warnings_on(db, assert_=False, *messages):
@@ -114,12 +115,14 @@ def uses_deprecated(*messages):
def decorate(fn, *args, **kw):
with expect_deprecated(*messages, assert_=False):
return fn(*args, **kw)
+
return decorate
@contextlib.contextmanager
-def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
- py2konly=False):
+def _expect_warnings(
+ exc_cls, messages, regex=True, assert_=True, py2konly=False
+):
if regex:
filters = [re.compile(msg, re.I | re.S) for msg in messages]
@@ -145,8 +148,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
return
for filter_ in filters:
- if (regex and filter_.match(msg)) or \
- (not regex and filter_ == msg):
+ if (regex and filter_.match(msg)) or (
+ not regex and filter_ == msg
+ ):
seen.discard(filter_)
break
else:
@@ -156,8 +160,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True,
yield
if assert_ and (not py2konly or not compat.py3k):
- assert not seen, "Warnings were not seen: %s" % \
- ", ".join("%r" % (s.pattern if regex else s) for s in seen)
+ assert not seen, "Warnings were not seen: %s" % ", ".join(
+ "%r" % (s.pattern if regex else s) for s in seen
+ )
def global_cleanup_assertions():
@@ -170,6 +175,7 @@ def global_cleanup_assertions():
"""
_assert_no_stray_pool_connections()
+
_STRAY_CONNECTION_FAILURES = 0
@@ -187,8 +193,10 @@ def _assert_no_stray_pool_connections():
# OK, let's be somewhat forgiving.
_STRAY_CONNECTION_FAILURES += 1
- print("Encountered a stray connection in test cleanup: %s"
- % str(pool._refs))
+ print(
+ "Encountered a stray connection in test cleanup: %s"
+ % str(pool._refs)
+ )
# then do a real GC sweep. We shouldn't even be here
# so a single sweep should really be doing it, otherwise
# there's probably a real unreachable cycle somewhere.
@@ -206,8 +214,8 @@ def _assert_no_stray_pool_connections():
pool._refs.clear()
_STRAY_CONNECTION_FAILURES = 0
warnings.warn(
- "Stray connection refused to leave "
- "after gc.collect(): %s" % err)
+ "Stray connection refused to leave " "after gc.collect(): %s" % err
+ )
elif _STRAY_CONNECTION_FAILURES > 10:
assert False, "Encountered more than 10 stray connections"
_STRAY_CONNECTION_FAILURES = 0
@@ -263,14 +271,16 @@ def not_in_(a, b, msg=None):
def startswith_(a, fragment, msg=None):
"""Assert a.startswith(fragment), with repr messaging on failure."""
assert a.startswith(fragment), msg or "%r does not start with %r" % (
- a, fragment)
+ a,
+ fragment,
+ )
def eq_ignore_whitespace(a, b, msg=None):
- a = re.sub(r'^\s+?|\n', "", a)
- a = re.sub(r' {2,}', " ", a)
- b = re.sub(r'^\s+?|\n', "", b)
- b = re.sub(r' {2,}', " ", b)
+ a = re.sub(r"^\s+?|\n", "", a)
+ a = re.sub(r" {2,}", " ", a)
+ b = re.sub(r"^\s+?|\n", "", b)
+ b = re.sub(r" {2,}", " ", b)
assert a == b, msg or "%r != %r" % (a, b)
@@ -291,32 +301,41 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
callable_(*args, **kwargs)
assert False, "Callable did not raise an exception"
except except_cls as e:
- assert re.search(
- msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
- print(util.text_type(e).encode('utf-8'))
+ assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (
+ msg,
+ e,
+ )
+ print(util.text_type(e).encode("utf-8"))
+
class AssertsCompiledSQL(object):
- def assert_compile(self, clause, result, params=None,
- checkparams=None, dialect=None,
- checkpositional=None,
- check_prefetch=None,
- use_default_dialect=False,
- allow_dialect_select=False,
- literal_binds=False,
- schema_translate_map=None):
+ def assert_compile(
+ self,
+ clause,
+ result,
+ params=None,
+ checkparams=None,
+ dialect=None,
+ checkpositional=None,
+ check_prefetch=None,
+ use_default_dialect=False,
+ allow_dialect_select=False,
+ literal_binds=False,
+ schema_translate_map=None,
+ ):
if use_default_dialect:
dialect = default.DefaultDialect()
elif allow_dialect_select:
dialect = None
else:
if dialect is None:
- dialect = getattr(self, '__dialect__', None)
+ dialect = getattr(self, "__dialect__", None)
if dialect is None:
dialect = config.db.dialect
- elif dialect == 'default':
+ elif dialect == "default":
dialect = default.DefaultDialect()
- elif dialect == 'default_enhanced':
+ elif dialect == "default_enhanced":
dialect = default.StrCompileDialect()
elif isinstance(dialect, util.string_types):
dialect = url.URL(dialect).get_dialect()()
@@ -325,13 +344,13 @@ class AssertsCompiledSQL(object):
compile_kwargs = {}
if schema_translate_map:
- kw['schema_translate_map'] = schema_translate_map
+ kw["schema_translate_map"] = schema_translate_map
if params is not None:
- kw['column_keys'] = list(params)
+ kw["column_keys"] = list(params)
if literal_binds:
- compile_kwargs['literal_binds'] = True
+ compile_kwargs["literal_binds"] = True
if isinstance(clause, orm.Query):
context = clause._compile_context()
@@ -343,25 +362,27 @@ class AssertsCompiledSQL(object):
clause = stmt_mock.mock_calls[0][1][0]
if compile_kwargs:
- kw['compile_kwargs'] = compile_kwargs
+ kw["compile_kwargs"] = compile_kwargs
c = clause.compile(dialect=dialect, **kw)
- param_str = repr(getattr(c, 'params', {}))
+ param_str = repr(getattr(c, "params", {}))
if util.py3k:
- param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
+ param_str = param_str.encode("utf-8").decode("ascii", "ignore")
print(
- ("\nSQL String:\n" +
- util.text_type(c) +
- param_str).encode('utf-8'))
+ ("\nSQL String:\n" + util.text_type(c) + param_str).encode(
+ "utf-8"
+ )
+ )
else:
print(
- "\nSQL String:\n" +
- util.text_type(c).encode('utf-8') +
- param_str)
+ "\nSQL String:\n"
+ + util.text_type(c).encode("utf-8")
+ + param_str
+ )
- cc = re.sub(r'[\n\t]', '', util.text_type(c))
+ cc = re.sub(r"[\n\t]", "", util.text_type(c))
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
@@ -375,7 +396,6 @@ class AssertsCompiledSQL(object):
class ComparesTables(object):
-
def assert_tables_equal(self, table, reflected_table, strict_types=False):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
@@ -386,8 +406,10 @@ class ComparesTables(object):
if strict_types:
msg = "Type '%s' doesn't correspond to type '%s'"
- assert isinstance(reflected_c.type, type(c.type)), \
- msg % (reflected_c.type, c.type)
+ assert isinstance(reflected_c.type, type(c.type)), msg % (
+ reflected_c.type,
+ c.type,
+ )
else:
self.assert_types_base(reflected_c, c)
@@ -396,20 +418,22 @@ class ComparesTables(object):
eq_(
{f.column.name for f in c.foreign_keys},
- {f.column.name for f in reflected_c.foreign_keys}
+ {f.column.name for f in reflected_c.foreign_keys},
)
if c.server_default:
- assert isinstance(reflected_c.server_default,
- schema.FetchedValue)
+ assert isinstance(
+ reflected_c.server_default, schema.FetchedValue
+ )
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name] is not None
def assert_types_base(self, c1, c2):
- assert c1.type._compare_type_affinity(c2.type),\
- "On column %r, type '%s' doesn't correspond to type '%s'" % \
- (c1.name, c1.type, c2.type)
+ assert c1.type._compare_type_affinity(c2.type), (
+ "On column %r, type '%s' doesn't correspond to type '%s'"
+ % (c1.name, c1.type, c2.type)
+ )
class AssertsExecutionResults(object):
@@ -419,15 +443,19 @@ class AssertsExecutionResults(object):
self.assert_list(result, class_, objects)
def assert_list(self, result, class_, list):
- self.assert_(len(result) == len(list),
- "result list is not the same size as test list, " +
- "for class " + class_.__name__)
+ self.assert_(
+ len(result) == len(list),
+ "result list is not the same size as test list, "
+ + "for class "
+ + class_.__name__,
+ )
for i in range(0, len(list)):
self.assert_row(class_, result[i], list[i])
def assert_row(self, class_, rowobj, desc):
- self.assert_(rowobj.__class__ is class_,
- "item class is not " + repr(class_))
+ self.assert_(
+ rowobj.__class__ is class_, "item class is not " + repr(class_)
+ )
for key, value in desc.items():
if isinstance(value, tuple):
if isinstance(value[1], list):
@@ -435,9 +463,11 @@ class AssertsExecutionResults(object):
else:
self.assert_row(value[0], getattr(rowobj, key), value[1])
else:
- self.assert_(getattr(rowobj, key) == value,
- "attribute %s value %s does not match %s" % (
- key, getattr(rowobj, key), value))
+ self.assert_(
+ getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s"
+ % (key, getattr(rowobj, key), value),
+ )
def assert_unordered_result(self, result, cls, *expected):
"""As assert_result, but the order of objects is not considered.
@@ -453,14 +483,19 @@ class AssertsExecutionResults(object):
found = util.IdentitySet(result)
expected = {immutabledict(e) for e in expected}
- for wrong in util.itertools_filterfalse(lambda o:
- isinstance(o, cls), found):
- fail('Unexpected type "%s", expected "%s"' % (
- type(wrong).__name__, cls.__name__))
+ for wrong in util.itertools_filterfalse(
+ lambda o: isinstance(o, cls), found
+ ):
+ fail(
+ 'Unexpected type "%s", expected "%s"'
+ % (type(wrong).__name__, cls.__name__)
+ )
if len(found) != len(expected):
- fail('Unexpected object count "%s", expected "%s"' % (
- len(found), len(expected)))
+ fail(
+ 'Unexpected object count "%s", expected "%s"'
+ % (len(found), len(expected))
+ )
NOVALUE = object()
@@ -469,7 +504,8 @@ class AssertsExecutionResults(object):
if isinstance(value, tuple):
try:
self.assert_unordered_result(
- getattr(obj, key), value[0], *value[1])
+ getattr(obj, key), value[0], *value[1]
+ )
except AssertionError:
return False
else:
@@ -484,8 +520,9 @@ class AssertsExecutionResults(object):
break
else:
fail(
- "Expected %s instance with attributes %s not found." % (
- cls.__name__, repr(expected_item)))
+ "Expected %s instance with attributes %s not found."
+ % (cls.__name__, repr(expected_item))
+ )
return True
def sql_execution_asserter(self, db=None):
@@ -505,9 +542,9 @@ class AssertsExecutionResults(object):
newrules = []
for rule in rules:
if isinstance(rule, dict):
- newrule = assertsql.AllOf(*[
- assertsql.CompiledSQL(k, v) for k, v in rule.items()
- ])
+ newrule = assertsql.AllOf(
+ *[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
+ )
else:
newrule = assertsql.CompiledSQL(*rule)
newrules.append(newrule)
@@ -516,7 +553,8 @@ class AssertsExecutionResults(object):
def assert_sql_count(self, db, callable_, count):
self.assert_sql_execution(
- db, callable_, assertsql.CountStatements(count))
+ db, callable_, assertsql.CountStatements(count)
+ )
def assert_multiple_sql_count(self, dbs, callable_, counts):
recs = [
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 7a525589d..d8e924cb6 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -26,8 +26,10 @@ class AssertRule(object):
pass
def no_more_statements(self):
- assert False, 'All statements are complete, but pending '\
- 'assertion rules remain'
+ assert False, (
+ "All statements are complete, but pending "
+ "assertion rules remain"
+ )
class SQLMatchRule(AssertRule):
@@ -44,12 +46,17 @@ class CursorSQL(SQLMatchRule):
def process_statement(self, execute_observed):
stmt = execute_observed.statements[0]
if self.statement != stmt.statement or (
- self.params is not None and self.params != stmt.parameters):
- self.errormessage = \
- "Testing for exact SQL %s parameters %s received %s %s" % (
- self.statement, self.params,
- stmt.statement, stmt.parameters
+ self.params is not None and self.params != stmt.parameters
+ ):
+ self.errormessage = (
+ "Testing for exact SQL %s parameters %s received %s %s"
+ % (
+ self.statement,
+ self.params,
+ stmt.statement,
+ stmt.parameters,
)
+ )
else:
execute_observed.statements.pop(0)
self.is_consumed = True
@@ -58,23 +65,22 @@ class CursorSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
-
- def __init__(self, statement, params=None, dialect='default'):
+ def __init__(self, statement, params=None, dialect="default"):
self.statement = statement
self.params = params
self.dialect = dialect
def _compare_sql(self, execute_observed, received_statement):
- stmt = re.sub(r'[\n\t]', '', self.statement)
+ stmt = re.sub(r"[\n\t]", "", self.statement)
return received_statement == stmt
def _compile_dialect(self, execute_observed):
- if self.dialect == 'default':
+ if self.dialect == "default":
return DefaultDialect()
else:
# ugh
- if self.dialect == 'postgresql':
- params = {'implicit_returning': True}
+ if self.dialect == "postgresql":
+ params = {"implicit_returning": True}
else:
params = {}
return url.URL(self.dialect).get_dialect()(**params)
@@ -86,36 +92,39 @@ class CompiledSQL(SQLMatchRule):
context = execute_observed.context
compare_dialect = self._compile_dialect(execute_observed)
if isinstance(context.compiled.statement, _DDLCompiles):
- compiled = \
- context.compiled.statement.compile(
- dialect=compare_dialect,
- schema_translate_map=context.
- execution_options.get('schema_translate_map'))
+ compiled = context.compiled.statement.compile(
+ dialect=compare_dialect,
+ schema_translate_map=context.execution_options.get(
+ "schema_translate_map"
+ ),
+ )
else:
- compiled = (
- context.compiled.statement.compile(
- dialect=compare_dialect,
- column_keys=context.compiled.column_keys,
- inline=context.compiled.inline,
- schema_translate_map=context.
- execution_options.get('schema_translate_map'))
+ compiled = context.compiled.statement.compile(
+ dialect=compare_dialect,
+ column_keys=context.compiled.column_keys,
+ inline=context.compiled.inline,
+ schema_translate_map=context.execution_options.get(
+ "schema_translate_map"
+ ),
)
- _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
+ _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
parameters = execute_observed.parameters
if not parameters:
_received_parameters = [compiled.construct_params()]
else:
_received_parameters = [
- compiled.construct_params(m) for m in parameters]
+ compiled.construct_params(m) for m in parameters
+ ]
return _received_statement, _received_parameters
def process_statement(self, execute_observed):
context = execute_observed.context
- _received_statement, _received_parameters = \
- self._received_statement(execute_observed)
+ _received_statement, _received_parameters = self._received_statement(
+ execute_observed
+ )
params = self._all_params(context)
equivalent = self._compare_sql(execute_observed, _received_statement)
@@ -132,8 +141,10 @@ class CompiledSQL(SQLMatchRule):
for param_key in param:
# a key in param did not match current
# 'received'
- if param_key not in received or \
- received[param_key] != param[param_key]:
+ if (
+ param_key not in received
+ or received[param_key] != param[param_key]
+ ):
break
else:
# all keys in param matched 'received';
@@ -153,8 +164,8 @@ class CompiledSQL(SQLMatchRule):
self.errormessage = None
else:
self.errormessage = self._failure_message(params) % {
- 'received_statement': _received_statement,
- 'received_parameters': _received_parameters
+ "received_statement": _received_statement,
+ "received_parameters": _received_parameters,
}
def _all_params(self, context):
@@ -171,11 +182,10 @@ class CompiledSQL(SQLMatchRule):
def _failure_message(self, expected_params):
return (
- 'Testing for compiled statement %r partial params %r, '
- 'received %%(received_statement)r with params '
- '%%(received_parameters)r' % (
- self.statement.replace('%', '%%'), expected_params
- )
+ "Testing for compiled statement %r partial params %r, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r"
+ % (self.statement.replace("%", "%%"), expected_params)
)
@@ -185,15 +195,13 @@ class RegexSQL(CompiledSQL):
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
- self.dialect = 'default'
+ self.dialect = "default"
def _failure_message(self, expected_params):
return (
- 'Testing for compiled statement ~%r partial params %r, '
- 'received %%(received_statement)r with params '
- '%%(received_parameters)r' % (
- self.orig_regex, expected_params
- )
+ "Testing for compiled statement ~%r partial params %r, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r" % (self.orig_regex, expected_params)
)
def _compare_sql(self, execute_observed, received_statement):
@@ -205,12 +213,13 @@ class DialectSQL(CompiledSQL):
return execute_observed.context.dialect
def _compare_no_space(self, real_stmt, received_stmt):
- stmt = re.sub(r'[\n\t]', '', real_stmt)
+ stmt = re.sub(r"[\n\t]", "", real_stmt)
return received_stmt == stmt
def _received_statement(self, execute_observed):
- received_stmt, received_params = super(DialectSQL, self).\
- _received_statement(execute_observed)
+ received_stmt, received_params = super(
+ DialectSQL, self
+ )._received_statement(execute_observed)
# TODO: why do we need this part?
for real_stmt in execute_observed.statements:
@@ -219,34 +228,33 @@ class DialectSQL(CompiledSQL):
else:
raise AssertionError(
"Can't locate compiled statement %r in list of "
- "statements actually invoked" % received_stmt)
+ "statements actually invoked" % received_stmt
+ )
return received_stmt, execute_observed.context.compiled_parameters
def _compare_sql(self, execute_observed, received_statement):
- stmt = re.sub(r'[\n\t]', '', self.statement)
+ stmt = re.sub(r"[\n\t]", "", self.statement)
# convert our comparison statement to have the
# paramstyle of the received
paramstyle = execute_observed.context.dialect.paramstyle
- if paramstyle == 'pyformat':
- stmt = re.sub(
- r':([\w_]+)', r"%(\1)s", stmt)
+ if paramstyle == "pyformat":
+ stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
else:
# positional params
repl = None
- if paramstyle == 'qmark':
+ if paramstyle == "qmark":
repl = "?"
- elif paramstyle == 'format':
+ elif paramstyle == "format":
repl = r"%s"
- elif paramstyle == 'numeric':
+ elif paramstyle == "numeric":
repl = None
- stmt = re.sub(r':([\w_]+)', repl, stmt)
+ stmt = re.sub(r":([\w_]+)", repl, stmt)
return received_statement == stmt
class CountStatements(AssertRule):
-
def __init__(self, count):
self.count = count
self._statement_count = 0
@@ -256,12 +264,13 @@ class CountStatements(AssertRule):
def no_more_statements(self):
if self.count != self._statement_count:
- assert False, 'desired statement count %d does not match %d' \
- % (self.count, self._statement_count)
+ assert False, "desired statement count %d does not match %d" % (
+ self.count,
+ self._statement_count,
+ )
class AllOf(AssertRule):
-
def __init__(self, *rules):
self.rules = set(rules)
@@ -283,7 +292,6 @@ class AllOf(AssertRule):
class EachOf(AssertRule):
-
def __init__(self, *rules):
self.rules = list(rules)
@@ -309,7 +317,6 @@ class EachOf(AssertRule):
class Or(AllOf):
-
def process_statement(self, execute_observed):
for rule in self.rules:
rule.process_statement(execute_observed)
@@ -331,7 +338,8 @@ class SQLExecuteObserved(object):
class SQLCursorExecuteObserved(
collections.namedtuple(
"SQLCursorExecuteObserved",
- ["statement", "parameters", "context", "executemany"])
+ ["statement", "parameters", "context", "executemany"],
+ )
):
pass
@@ -374,21 +382,25 @@ def assert_engine(engine):
orig[:] = clauseelement, multiparams, params
@event.listens_for(engine, "after_cursor_execute")
- def cursor_execute(conn, cursor, statement, parameters,
- context, executemany):
+ def cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
if not context:
return
# then grab real cursor statements and associate them all
# around a single context
- if asserter.accumulated and \
- asserter.accumulated[-1].context is context:
+ if (
+ asserter.accumulated
+ and asserter.accumulated[-1].context is context
+ ):
obs = asserter.accumulated[-1]
else:
obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
asserter.accumulated.append(obs)
obs.statements.append(
SQLCursorExecuteObserved(
- statement, parameters, context, executemany)
+ statement, parameters, context, executemany
+ )
)
try:
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index e9cfb3de9..1ff282af5 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -64,8 +64,9 @@ class Config(object):
assert _current, "Can't push without a default Config set up"
cls.push(
Config(
- db, _current.db_opts, _current.options, _current.file_config),
- namespace
+ db, _current.db_opts, _current.options, _current.file_config
+ ),
+ namespace,
)
@classmethod
@@ -94,4 +95,3 @@ class Config(object):
def skip_test(msg):
raise _skip_test_exception(msg)
-
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
index d17e30edf..074e3b338 100644
--- a/lib/sqlalchemy/testing/engines.py
+++ b/lib/sqlalchemy/testing/engines.py
@@ -16,7 +16,6 @@ import warnings
class ConnectionKiller(object):
-
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
self.testing_engines = weakref.WeakKeyDictionary()
@@ -39,8 +38,8 @@ class ConnectionKiller(object):
fn()
except Exception as e:
warnings.warn(
- "testing_reaper couldn't "
- "rollback/close connection: %s" % e)
+ "testing_reaper couldn't " "rollback/close connection: %s" % e
+ )
def rollback_all(self):
for rec in list(self.proxy_refs):
@@ -97,18 +96,19 @@ class ConnectionKiller(object):
if rec.is_valid:
assert False
+
testing_reaper = ConnectionKiller()
def drop_all_tables(metadata, bind):
testing_reaper.close_all()
- if hasattr(bind, 'close'):
+ if hasattr(bind, "close"):
bind.close()
if not config.db.dialect.supports_alter:
from . import assertions
- with assertions.expect_warnings(
- "Can't sort tables", assert_=False):
+
+ with assertions.expect_warnings("Can't sort tables", assert_=False):
metadata.drop_all(bind)
else:
metadata.drop_all(bind)
@@ -151,19 +151,20 @@ def close_open_connections(fn, *args, **kw):
def all_dialects(exclude=None):
import sqlalchemy.databases as d
+
for name in d.__all__:
# TEMPORARY
if exclude and name in exclude:
continue
mod = getattr(d, name, None)
if not mod:
- mod = getattr(__import__(
- 'sqlalchemy.databases.%s' % name).databases, name)
+ mod = getattr(
+ __import__("sqlalchemy.databases.%s" % name).databases, name
+ )
yield mod.dialect()
class ReconnectFixture(object):
-
def __init__(self, dbapi):
self.dbapi = dbapi
self.connections = []
@@ -191,8 +192,8 @@ class ReconnectFixture(object):
fn()
except Exception as e:
warnings.warn(
- "ReconnectFixture couldn't "
- "close connection: %s" % e)
+ "ReconnectFixture couldn't " "close connection: %s" % e
+ )
def shutdown(self, stop=False):
# TODO: this doesn't cover all cases
@@ -214,7 +215,7 @@ def reconnecting_engine(url=None, options=None):
dbapi = config.db.dialect.dbapi
if not options:
options = {}
- options['module'] = ReconnectFixture(dbapi)
+ options["module"] = ReconnectFixture(dbapi)
engine = testing_engine(url, options)
_dispose = engine.dispose
@@ -238,7 +239,7 @@ def testing_engine(url=None, options=None):
if not options:
use_reaper = True
else:
- use_reaper = options.pop('use_reaper', True)
+ use_reaper = options.pop("use_reaper", True)
url = url or config.db.url
@@ -253,15 +254,15 @@ def testing_engine(url=None, options=None):
default_opt.update(options)
engine = create_engine(url, **options)
- engine._has_events = True # enable event blocks, helps with profiling
+ engine._has_events = True # enable event blocks, helps with profiling
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
engine.pool._max_overflow = 0
if use_reaper:
- event.listen(engine.pool, 'connect', testing_reaper.connect)
- event.listen(engine.pool, 'checkout', testing_reaper.checkout)
- event.listen(engine.pool, 'invalidate', testing_reaper.invalidate)
+ event.listen(engine.pool, "connect", testing_reaper.connect)
+ event.listen(engine.pool, "checkout", testing_reaper.checkout)
+ event.listen(engine.pool, "invalidate", testing_reaper.invalidate)
testing_reaper.add_engine(engine)
return engine
@@ -290,19 +291,17 @@ def mock_engine(dialect_name=None):
buffer.append(sql)
def assert_sql(stmts):
- recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
+ recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
assert recv == stmts, recv
def print_sql():
d = engine.dialect
- return "\n".join(
- str(s.compile(dialect=d))
- for s in engine.mock
- )
-
- engine = create_engine(dialect_name + '://',
- strategy='mock', executor=executor)
- assert not hasattr(engine, 'mock')
+ return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)
+
+ engine = create_engine(
+ dialect_name + "://", strategy="mock", executor=executor
+ )
+ assert not hasattr(engine, "mock")
engine.mock = buffer
engine.assert_sql = assert_sql
engine.print_sql = print_sql
@@ -358,14 +357,15 @@ class DBAPIProxyConnection(object):
return getattr(self.conn, key)
-def proxying_engine(conn_cls=DBAPIProxyConnection,
- cursor_cls=DBAPIProxyCursor):
+def proxying_engine(
+ conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor
+):
"""Produce an engine that provides proxy hooks for
common methods.
"""
+
def mock_conn():
return conn_cls(config.db, cursor_cls)
- return testing_engine(options={'creator': mock_conn})
-
+ return testing_engine(options={"creator": mock_conn})
diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py
index b634735fc..42c42149c 100644
--- a/lib/sqlalchemy/testing/entities.py
+++ b/lib/sqlalchemy/testing/entities.py
@@ -12,7 +12,6 @@ _repr_stack = set()
class BasicEntity(object):
-
def __init__(self, **kw):
for key, value in kw.items():
setattr(self, key, value)
@@ -24,17 +23,22 @@ class BasicEntity(object):
try:
return "%s(%s)" % (
(self.__class__.__name__),
- ', '.join(["%s=%r" % (key, getattr(self, key))
- for key in sorted(self.__dict__.keys())
- if not key.startswith('_')]))
+ ", ".join(
+ [
+ "%s=%r" % (key, getattr(self, key))
+ for key in sorted(self.__dict__.keys())
+ if not key.startswith("_")
+ ]
+ ),
+ )
finally:
_repr_stack.remove(id(self))
+
_recursion_stack = set()
class ComparableEntity(BasicEntity):
-
def __hash__(self):
return hash(self.__class__)
@@ -75,7 +79,7 @@ class ComparableEntity(BasicEntity):
b = other
for attr in list(a.__dict__):
- if attr.startswith('_'):
+ if attr.startswith("_"):
continue
value = getattr(a, attr)
@@ -85,9 +89,10 @@ class ComparableEntity(BasicEntity):
except (AttributeError, sa_exc.UnboundExecutionError):
return False
- if hasattr(value, '__iter__'):
- if hasattr(value, '__getitem__') and not hasattr(
- value, 'keys'):
+ if hasattr(value, "__iter__"):
+ if hasattr(value, "__getitem__") and not hasattr(
+ value, "keys"
+ ):
if list(value) != list(battr):
return False
else:
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
index 512fffb3b..9ed9e42c3 100644
--- a/lib/sqlalchemy/testing/exclusions.py
+++ b/lib/sqlalchemy/testing/exclusions.py
@@ -16,6 +16,7 @@ from . import config
from .. import util
from ..util import decorator
+
def skip_if(predicate, reason=None):
rule = compound()
pred = _as_predicate(predicate, reason)
@@ -70,15 +71,15 @@ class compound(object):
def matching_config_reasons(self, config):
return [
- predicate._as_string(config) for predicate
- in self.skips.union(self.fails)
+ predicate._as_string(config)
+ for predicate in self.skips.union(self.fails)
if predicate(config)
]
def include_test(self, include_tags, exclude_tags):
return bool(
- not self.tags.intersection(exclude_tags) and
- (not include_tags or self.tags.intersection(include_tags))
+ not self.tags.intersection(exclude_tags)
+ and (not include_tags or self.tags.intersection(include_tags))
)
def _extend(self, other):
@@ -87,13 +88,14 @@ class compound(object):
self.tags.update(other.tags)
def __call__(self, fn):
- if hasattr(fn, '_sa_exclusion_extend'):
+ if hasattr(fn, "_sa_exclusion_extend"):
fn._sa_exclusion_extend._extend(self)
return fn
@decorator
def decorate(fn, *args, **kw):
return self._do(config._current, fn, *args, **kw)
+
decorated = decorate(fn)
decorated._sa_exclusion_extend = self
return decorated
@@ -113,10 +115,7 @@ class compound(object):
def _do(self, cfg, fn, *args, **kw):
for skip in self.skips:
if skip(cfg):
- msg = "'%s' : %s" % (
- fn.__name__,
- skip._as_string(cfg)
- )
+ msg = "'%s' : %s" % (fn.__name__, skip._as_string(cfg))
config.skip_test(msg)
try:
@@ -127,16 +126,20 @@ class compound(object):
self._expect_success(cfg, name=fn.__name__)
return return_value
- def _expect_failure(self, config, ex, name='block'):
+ def _expect_failure(self, config, ex, name="block"):
for fail in self.fails:
if fail(config):
- print(("%s failed as expected (%s): %s " % (
- name, fail._as_string(config), str(ex))))
+ print(
+ (
+ "%s failed as expected (%s): %s "
+ % (name, fail._as_string(config), str(ex))
+ )
+ )
break
else:
util.raise_from_cause(ex)
- def _expect_success(self, config, name='block'):
+ def _expect_success(self, config, name="block"):
if not self.fails:
return
for fail in self.fails:
@@ -144,13 +147,12 @@ class compound(object):
break
else:
raise AssertionError(
- "Unexpected success for '%s' (%s)" %
- (
+ "Unexpected success for '%s' (%s)"
+ % (
name,
" and ".join(
- fail._as_string(config)
- for fail in self.fails
- )
+ fail._as_string(config) for fail in self.fails
+ ),
)
)
@@ -186,21 +188,24 @@ class Predicate(object):
return predicate
elif isinstance(predicate, (list, set)):
return OrPredicate(
- [cls.as_predicate(pred) for pred in predicate],
- description)
+ [cls.as_predicate(pred) for pred in predicate], description
+ )
elif isinstance(predicate, tuple):
return SpecPredicate(*predicate)
elif isinstance(predicate, util.string_types):
tokens = re.match(
- r'([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?', predicate)
+ r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
+ )
if not tokens:
raise ValueError(
- "Couldn't locate DB name in predicate: %r" % predicate)
+ "Couldn't locate DB name in predicate: %r" % predicate
+ )
db = tokens.group(1)
op = tokens.group(2)
spec = (
tuple(int(d) for d in tokens.group(3).split("."))
- if tokens.group(3) else None
+ if tokens.group(3)
+ else None
)
return SpecPredicate(db, op, spec, description=description)
@@ -215,11 +220,13 @@ class Predicate(object):
bool_ = not negate
return self.description % {
"driver": config.db.url.get_driver_name()
- if config else "<no driver>",
+ if config
+ else "<no driver>",
"database": config.db.url.get_backend_name()
- if config else "<no database>",
+ if config
+ else "<no database>",
"doesnt_support": "doesn't support" if bool_ else "does support",
- "does_support": "does support" if bool_ else "doesn't support"
+ "does_support": "does support" if bool_ else "doesn't support",
}
def _as_string(self, config=None, negate=False):
@@ -246,21 +253,21 @@ class SpecPredicate(Predicate):
self.description = description
_ops = {
- '<': operator.lt,
- '>': operator.gt,
- '==': operator.eq,
- '!=': operator.ne,
- '<=': operator.le,
- '>=': operator.ge,
- 'in': operator.contains,
- 'between': lambda val, pair: val >= pair[0] and val <= pair[1],
+ "<": operator.lt,
+ ">": operator.gt,
+ "==": operator.eq,
+ "!=": operator.ne,
+ "<=": operator.le,
+ ">=": operator.ge,
+ "in": operator.contains,
+ "between": lambda val, pair: val >= pair[0] and val <= pair[1],
}
def __call__(self, config):
engine = config.db
if "+" in self.db:
- dialect, driver = self.db.split('+')
+ dialect, driver = self.db.split("+")
else:
dialect, driver = self.db, None
@@ -273,8 +280,9 @@ class SpecPredicate(Predicate):
assert driver is None, "DBAPI version specs not supported yet"
version = _server_version(engine)
- oper = hasattr(self.op, '__call__') and self.op \
- or self._ops[self.op]
+ oper = (
+ hasattr(self.op, "__call__") and self.op or self._ops[self.op]
+ )
return oper(version, self.spec)
else:
return True
@@ -289,17 +297,9 @@ class SpecPredicate(Predicate):
return "%s" % self.db
else:
if negate:
- return "not %s %s %s" % (
- self.db,
- self.op,
- self.spec
- )
+ return "not %s %s %s" % (self.db, self.op, self.spec)
else:
- return "%s %s %s" % (
- self.db,
- self.op,
- self.spec
- )
+ return "%s %s %s" % (self.db, self.op, self.spec)
class LambdaPredicate(Predicate):
@@ -356,8 +356,9 @@ class OrPredicate(Predicate):
conjunction = " and "
else:
conjunction = " or "
- return conjunction.join(p._as_string(config, negate=negate)
- for p in self.predicates)
+ return conjunction.join(
+ p._as_string(config, negate=negate) for p in self.predicates
+ )
def _negation_str(self, config):
if self.description is not None:
@@ -387,7 +388,7 @@ def _server_version(engine):
# force metadata to be retrieved
conn = engine.connect()
- version = getattr(engine.dialect, 'server_version_info', None)
+ version = getattr(engine.dialect, "server_version_info", None)
if version is None:
version = ()
conn.close()
@@ -395,9 +396,7 @@ def _server_version(engine):
def db_spec(*dbs):
- return OrPredicate(
- [Predicate.as_predicate(db) for db in dbs]
- )
+ return OrPredicate([Predicate.as_predicate(db) for db in dbs])
def open():
@@ -422,11 +421,7 @@ def fails_on(db, reason=None):
def fails_on_everything_except(*dbs):
- return succeeds_if(
- OrPredicate([
- Predicate.as_predicate(db) for db in dbs
- ])
- )
+ return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
def skip(db, reason=None):
@@ -435,8 +430,9 @@ def skip(db, reason=None):
def only_on(dbs, reason=None):
return only_if(
- OrPredicate([Predicate.as_predicate(db, reason)
- for db in util.to_list(dbs)])
+ OrPredicate(
+ [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
+ )
)
@@ -446,7 +442,6 @@ def exclude(db, op, spec, reason=None):
def against(config, *queries):
assert queries, "no queries sent!"
- return OrPredicate([
- Predicate.as_predicate(query)
- for query in queries
- ])(config)
+ return OrPredicate([Predicate.as_predicate(query) for query in queries])(
+ config
+ )
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index dd0fa5a48..98184cdd4 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -54,19 +54,19 @@ class TestBase(object):
class TablesTest(TestBase):
# 'once', None
- run_setup_bind = 'once'
+ run_setup_bind = "once"
# 'once', 'each', None
- run_define_tables = 'once'
+ run_define_tables = "once"
# 'once', 'each', None
- run_create_tables = 'once'
+ run_create_tables = "once"
# 'once', 'each', None
- run_inserts = 'each'
+ run_inserts = "each"
# 'each', None
- run_deletes = 'each'
+ run_deletes = "each"
# 'once', None
run_dispose_bind = None
@@ -86,10 +86,10 @@ class TablesTest(TestBase):
@classmethod
def _init_class(cls):
- if cls.run_define_tables == 'each':
- if cls.run_create_tables == 'once':
- cls.run_create_tables = 'each'
- assert cls.run_inserts in ('each', None)
+ if cls.run_define_tables == "each":
+ if cls.run_create_tables == "once":
+ cls.run_create_tables = "each"
+ assert cls.run_inserts in ("each", None)
cls.other = adict()
cls.tables = adict()
@@ -100,40 +100,40 @@ class TablesTest(TestBase):
@classmethod
def _setup_once_inserts(cls):
- if cls.run_inserts == 'once':
+ if cls.run_inserts == "once":
cls._load_fixtures()
cls.insert_data()
@classmethod
def _setup_once_tables(cls):
- if cls.run_define_tables == 'once':
+ if cls.run_define_tables == "once":
cls.define_tables(cls.metadata)
- if cls.run_create_tables == 'once':
+ if cls.run_create_tables == "once":
cls.metadata.create_all(cls.bind)
cls.tables.update(cls.metadata.tables)
def _setup_each_tables(self):
- if self.run_define_tables == 'each':
+ if self.run_define_tables == "each":
self.tables.clear()
- if self.run_create_tables == 'each':
+ if self.run_create_tables == "each":
drop_all_tables(self.metadata, self.bind)
self.metadata.clear()
self.define_tables(self.metadata)
- if self.run_create_tables == 'each':
+ if self.run_create_tables == "each":
self.metadata.create_all(self.bind)
self.tables.update(self.metadata.tables)
- elif self.run_create_tables == 'each':
+ elif self.run_create_tables == "each":
drop_all_tables(self.metadata, self.bind)
self.metadata.create_all(self.bind)
def _setup_each_inserts(self):
- if self.run_inserts == 'each':
+ if self.run_inserts == "each":
self._load_fixtures()
self.insert_data()
def _teardown_each_tables(self):
# no need to run deletes if tables are recreated on setup
- if self.run_define_tables != 'each' and self.run_deletes == 'each':
+ if self.run_define_tables != "each" and self.run_deletes == "each":
with self.bind.connect() as conn:
for table in reversed(self.metadata.sorted_tables):
try:
@@ -141,7 +141,8 @@ class TablesTest(TestBase):
except sa.exc.DBAPIError as ex:
util.print_(
("Error emptying table %s: %r" % (table, ex)),
- file=sys.stderr)
+ file=sys.stderr,
+ )
def setup(self):
self._setup_each_tables()
@@ -155,7 +156,7 @@ class TablesTest(TestBase):
if cls.run_create_tables:
drop_all_tables(cls.metadata, cls.bind)
- if cls.run_dispose_bind == 'once':
+ if cls.run_dispose_bind == "once":
cls.dispose_bind(cls.bind)
cls.metadata.bind = None
@@ -173,9 +174,9 @@ class TablesTest(TestBase):
@classmethod
def dispose_bind(cls, bind):
- if hasattr(bind, 'dispose'):
+ if hasattr(bind, "dispose"):
bind.dispose()
- elif hasattr(bind, 'close'):
+ elif hasattr(bind, "close"):
bind.close()
@classmethod
@@ -212,8 +213,12 @@ class TablesTest(TestBase):
continue
cls.bind.execute(
table.insert(),
- [dict(zip(headers[table], column_values))
- for column_values in rows[table]])
+ [
+ dict(zip(headers[table], column_values))
+ for column_values in rows[table]
+ ],
+ )
+
from sqlalchemy import event
@@ -236,7 +241,6 @@ class RemovesEvents(object):
class _ORMTest(object):
-
@classmethod
def teardown_class(cls):
sa.orm.session.Session.close_all()
@@ -249,10 +253,10 @@ class ORMTest(_ORMTest, TestBase):
class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
# 'once', 'each', None
- run_setup_classes = 'once'
+ run_setup_classes = "once"
# 'once', 'each', None
- run_setup_mappers = 'each'
+ run_setup_mappers = "each"
classes = None
@@ -292,20 +296,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
@classmethod
def _setup_once_classes(cls):
- if cls.run_setup_classes == 'once':
+ if cls.run_setup_classes == "once":
cls._with_register_classes(cls.setup_classes)
@classmethod
def _setup_once_mappers(cls):
- if cls.run_setup_mappers == 'once':
+ if cls.run_setup_mappers == "once":
cls._with_register_classes(cls.setup_mappers)
def _setup_each_mappers(self):
- if self.run_setup_mappers == 'each':
+ if self.run_setup_mappers == "each":
self._with_register_classes(self.setup_mappers)
def _setup_each_classes(self):
- if self.run_setup_classes == 'each':
+ if self.run_setup_classes == "each":
self._with_register_classes(self.setup_classes)
@classmethod
@@ -339,11 +343,11 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
# some tests create mappers in the test bodies
# and will define setup_mappers as None -
# clear mappers in any case
- if self.run_setup_mappers != 'once':
+ if self.run_setup_mappers != "once":
sa.orm.clear_mappers()
def _teardown_each_classes(self):
- if self.run_setup_classes != 'once':
+ if self.run_setup_classes != "once":
self.classes.clear()
@classmethod
@@ -356,8 +360,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
class DeclarativeMappedTest(MappedTest):
- run_setup_classes = 'once'
- run_setup_mappers = 'once'
+ run_setup_classes = "once"
+ run_setup_mappers = "once"
@classmethod
def _setup_once_tables(cls):
@@ -370,15 +374,16 @@ class DeclarativeMappedTest(MappedTest):
class FindFixtureDeclarative(DeclarativeMeta):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
- return DeclarativeMeta.__init__(
- cls, classname, bases, dict_)
+ return DeclarativeMeta.__init__(cls, classname, bases, dict_)
class DeclarativeBasic(object):
__table_cls__ = schema.Table
- _DeclBase = declarative_base(metadata=cls.metadata,
- metaclass=FindFixtureDeclarative,
- cls=DeclarativeBasic)
+ _DeclBase = declarative_base(
+ metadata=cls.metadata,
+ metaclass=FindFixtureDeclarative,
+ cls=DeclarativeBasic,
+ )
cls.DeclarativeBasic = _DeclBase
fn()
diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py
index ea0a8da82..dc530af5e 100644
--- a/lib/sqlalchemy/testing/mock.py
+++ b/lib/sqlalchemy/testing/mock.py
@@ -18,4 +18,5 @@ else:
except ImportError:
raise ImportError(
"SQLAlchemy's test suite requires the "
- "'mock' library as of 0.8.2.")
+ "'mock' library as of 0.8.2."
+ )
diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py
index 087fc1fe6..e84cbde44 100644
--- a/lib/sqlalchemy/testing/pickleable.py
+++ b/lib/sqlalchemy/testing/pickleable.py
@@ -46,29 +46,28 @@ class Parent(fixtures.ComparableEntity):
class Screen(object):
-
def __init__(self, obj, parent=None):
self.obj = obj
self.parent = parent
class Foo(object):
-
def __init__(self, moredata):
- self.data = 'im data'
- self.stuff = 'im stuff'
+ self.data = "im data"
+ self.stuff = "im stuff"
self.moredata = moredata
__hash__ = object.__hash__
def __eq__(self, other):
- return other.data == self.data and \
- other.stuff == self.stuff and \
- other.moredata == self.moredata
+ return (
+ other.data == self.data
+ and other.stuff == self.stuff
+ and other.moredata == self.moredata
+ )
class Bar(object):
-
def __init__(self, x, y):
self.x = x
self.y = y
@@ -76,35 +75,36 @@ class Bar(object):
__hash__ = object.__hash__
def __eq__(self, other):
- return other.__class__ is self.__class__ and \
- other.x == self.x and \
- other.y == self.y
+ return (
+ other.__class__ is self.__class__
+ and other.x == self.x
+ and other.y == self.y
+ )
def __str__(self):
return "Bar(%d, %d)" % (self.x, self.y)
class OldSchool:
-
def __init__(self, x, y):
self.x = x
self.y = y
def __eq__(self, other):
- return other.__class__ is self.__class__ and \
- other.x == self.x and \
- other.y == self.y
+ return (
+ other.__class__ is self.__class__
+ and other.x == self.x
+ and other.y == self.y
+ )
class OldSchoolWithoutCompare:
-
def __init__(self, x, y):
self.x = x
self.y = y
class BarWithoutCompare(object):
-
def __init__(self, x, y):
self.x = x
self.y = y
@@ -114,7 +114,6 @@ class BarWithoutCompare(object):
class NotComparable(object):
-
def __init__(self, data):
self.data = data
@@ -129,7 +128,6 @@ class NotComparable(object):
class BrokenComparable(object):
-
def __init__(self, data):
self.data = data
diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py
index 497fcb7e5..bb52c125c 100644
--- a/lib/sqlalchemy/testing/plugin/bootstrap.py
+++ b/lib/sqlalchemy/testing/plugin/bootstrap.py
@@ -20,20 +20,23 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0.
import os
import sys
-bootstrap_file = locals()['bootstrap_file']
-to_bootstrap = locals()['to_bootstrap']
+bootstrap_file = locals()["bootstrap_file"]
+to_bootstrap = locals()["to_bootstrap"]
def load_file_as_module(name):
path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
if sys.version_info >= (3, 3):
from importlib import machinery
+
mod = machinery.SourceFileLoader(name, path).load_module()
else:
import imp
+
mod = imp.load_source(name, path)
return mod
+
if to_bootstrap == "pytest":
sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py
index 20ea61d89..0c28a5213 100644
--- a/lib/sqlalchemy/testing/plugin/noseplugin.py
+++ b/lib/sqlalchemy/testing/plugin/noseplugin.py
@@ -25,6 +25,7 @@ import sys
from nose.plugins import Plugin
import nose
+
fixtures = None
py3k = sys.version_info >= (3, 0)
@@ -33,7 +34,7 @@ py3k = sys.version_info >= (3, 0)
class NoseSQLAlchemy(Plugin):
enabled = True
- name = 'sqla_testing'
+ name = "sqla_testing"
score = 100
def options(self, parser, env=os.environ):
@@ -41,10 +42,14 @@ class NoseSQLAlchemy(Plugin):
opt = parser.add_option
def make_option(name, **kw):
- callback_ = kw.pop("callback", None) or kw.pop("zeroarg_callback", None)
+ callback_ = kw.pop("callback", None) or kw.pop(
+ "zeroarg_callback", None
+ )
if callback_:
+
def wrap_(option, opt_str, value, parser):
callback_(opt_str, value, parser)
+
kw["callback"] = wrap_
opt(name, **kw)
@@ -73,7 +78,7 @@ class NoseSQLAlchemy(Plugin):
def wantMethod(self, fn):
if py3k:
- if not hasattr(fn.__self__, 'cls'):
+ if not hasattr(fn.__self__, "cls"):
return False
cls = fn.__self__.cls
else:
@@ -84,24 +89,24 @@ class NoseSQLAlchemy(Plugin):
return plugin_base.want_class(cls)
def beforeTest(self, test):
- if not hasattr(test.test, 'cls'):
+ if not hasattr(test.test, "cls"):
return
plugin_base.before_test(
test,
test.test.cls.__module__,
- test.test.cls, test.test.method.__name__)
+ test.test.cls,
+ test.test.method.__name__,
+ )
def afterTest(self, test):
plugin_base.after_test(test)
def startContext(self, ctx):
- if not isinstance(ctx, type) \
- or not issubclass(ctx, fixtures.TestBase):
+ if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
return
plugin_base.start_test_class(ctx)
def stopContext(self, ctx):
- if not isinstance(ctx, type) \
- or not issubclass(ctx, fixtures.TestBase):
+ if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase):
return
plugin_base.stop_test_class(ctx)
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 0ffcae093..5d6bf2975 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -46,58 +46,130 @@ options = None
def setup_options(make_option):
- make_option("--log-info", action="callback", type="string", callback=_log,
- help="turn on info logging for <LOG> (multiple OK)")
- make_option("--log-debug", action="callback",
- type="string", callback=_log,
- help="turn on debug logging for <LOG> (multiple OK)")
- make_option("--db", action="append", type="string", dest="db",
- help="Use prefab database uri. Multiple OK, "
- "first one is run by default.")
- make_option('--dbs', action='callback', zeroarg_callback=_list_dbs,
- help="List available prefab dbs")
- make_option("--dburi", action="append", type="string", dest="dburi",
- help="Database uri. Multiple OK, "
- "first one is run by default.")
- make_option("--dropfirst", action="store_true", dest="dropfirst",
- help="Drop all tables in the target database first")
- make_option("--backend-only", action="store_true", dest="backend_only",
- help="Run only tests marked with __backend__")
- make_option("--nomemory", action="store_true", dest="nomemory",
- help="Don't run memory profiling tests")
- make_option("--postgresql-templatedb", type="string",
- help="name of template database to use for PostgreSQL "
- "CREATE DATABASE (defaults to current database)")
- make_option("--low-connections", action="store_true",
- dest="low_connections",
- help="Use a low number of distinct connections - "
- "i.e. for Oracle TNS")
- make_option("--write-idents", type="string", dest="write_idents",
- help="write out generated follower idents to <file>, "
- "when -n<num> is used")
- make_option("--reversetop", action="store_true",
- dest="reversetop", default=False,
- help="Use a random-ordering set implementation in the ORM "
- "(helps reveal dependency issues)")
- make_option("--requirements", action="callback", type="string",
- callback=_requirements_opt,
- help="requirements class for testing, overrides setup.cfg")
- make_option("--with-cdecimal", action="store_true",
- dest="cdecimal", default=False,
- help="Monkeypatch the cdecimal library into Python 'decimal' "
- "for all tests")
- make_option("--include-tag", action="callback", callback=_include_tag,
- type="string",
- help="Include tests with tag <tag>")
- make_option("--exclude-tag", action="callback", callback=_exclude_tag,
- type="string",
- help="Exclude tests with tag <tag>")
- make_option("--write-profiles", action="store_true",
- dest="write_profiles", default=False,
- help="Write/update failing profiling data.")
- make_option("--force-write-profiles", action="store_true",
- dest="force_write_profiles", default=False,
- help="Unconditionally write/update profiling data.")
+ make_option(
+ "--log-info",
+ action="callback",
+ type="string",
+ callback=_log,
+ help="turn on info logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--log-debug",
+ action="callback",
+ type="string",
+ callback=_log,
+ help="turn on debug logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--db",
+ action="append",
+ type="string",
+ dest="db",
+ help="Use prefab database uri. Multiple OK, "
+ "first one is run by default.",
+ )
+ make_option(
+ "--dbs",
+ action="callback",
+ zeroarg_callback=_list_dbs,
+ help="List available prefab dbs",
+ )
+ make_option(
+ "--dburi",
+ action="append",
+ type="string",
+ dest="dburi",
+ help="Database uri. Multiple OK, " "first one is run by default.",
+ )
+ make_option(
+ "--dropfirst",
+ action="store_true",
+ dest="dropfirst",
+ help="Drop all tables in the target database first",
+ )
+ make_option(
+ "--backend-only",
+ action="store_true",
+ dest="backend_only",
+ help="Run only tests marked with __backend__",
+ )
+ make_option(
+ "--nomemory",
+ action="store_true",
+ dest="nomemory",
+ help="Don't run memory profiling tests",
+ )
+ make_option(
+ "--postgresql-templatedb",
+ type="string",
+ help="name of template database to use for PostgreSQL "
+ "CREATE DATABASE (defaults to current database)",
+ )
+ make_option(
+ "--low-connections",
+ action="store_true",
+ dest="low_connections",
+ help="Use a low number of distinct connections - "
+ "i.e. for Oracle TNS",
+ )
+ make_option(
+ "--write-idents",
+ type="string",
+ dest="write_idents",
+ help="write out generated follower idents to <file>, "
+ "when -n<num> is used",
+ )
+ make_option(
+ "--reversetop",
+ action="store_true",
+ dest="reversetop",
+ default=False,
+ help="Use a random-ordering set implementation in the ORM "
+ "(helps reveal dependency issues)",
+ )
+ make_option(
+ "--requirements",
+ action="callback",
+ type="string",
+ callback=_requirements_opt,
+ help="requirements class for testing, overrides setup.cfg",
+ )
+ make_option(
+ "--with-cdecimal",
+ action="store_true",
+ dest="cdecimal",
+ default=False,
+ help="Monkeypatch the cdecimal library into Python 'decimal' "
+ "for all tests",
+ )
+ make_option(
+ "--include-tag",
+ action="callback",
+ callback=_include_tag,
+ type="string",
+ help="Include tests with tag <tag>",
+ )
+ make_option(
+ "--exclude-tag",
+ action="callback",
+ callback=_exclude_tag,
+ type="string",
+ help="Exclude tests with tag <tag>",
+ )
+ make_option(
+ "--write-profiles",
+ action="store_true",
+ dest="write_profiles",
+ default=False,
+ help="Write/update failing profiling data.",
+ )
+ make_option(
+ "--force-write-profiles",
+ action="store_true",
+ dest="force_write_profiles",
+ default=False,
+ help="Unconditionally write/update profiling data.",
+ )
def configure_follower(follower_ident):
@@ -108,6 +180,7 @@ def configure_follower(follower_ident):
"""
from sqlalchemy.testing import provision
+
provision.FOLLOWER_IDENT = follower_ident
@@ -121,9 +194,9 @@ def memoize_important_follower_config(dict_):
callables, so we have to just copy all of that over.
"""
- dict_['memoized_config'] = {
- 'include_tags': include_tags,
- 'exclude_tags': exclude_tags
+ dict_["memoized_config"] = {
+ "include_tags": include_tags,
+ "exclude_tags": exclude_tags,
}
@@ -134,14 +207,14 @@ def restore_important_follower_config(dict_):
"""
global include_tags, exclude_tags
- include_tags.update(dict_['memoized_config']['include_tags'])
- exclude_tags.update(dict_['memoized_config']['exclude_tags'])
+ include_tags.update(dict_["memoized_config"]["include_tags"])
+ exclude_tags.update(dict_["memoized_config"]["exclude_tags"])
def read_config():
global file_config
file_config = configparser.ConfigParser()
- file_config.read(['setup.cfg', 'test.cfg'])
+ file_config.read(["setup.cfg", "test.cfg"])
def pre_begin(opt):
@@ -155,6 +228,7 @@ def pre_begin(opt):
def set_coverage_flag(value):
options.has_coverage = value
+
_skip_test_exception = None
@@ -171,34 +245,33 @@ def post_begin():
# late imports, has to happen after config as well
# as nose plugins like coverage
- global util, fixtures, engines, exclusions, \
- assertions, warnings, profiling,\
- config, testing
- from sqlalchemy import testing # noqa
+ global util, fixtures, engines, exclusions, assertions, warnings, profiling, config, testing
+ from sqlalchemy import testing # noqa
from sqlalchemy.testing import fixtures, engines, exclusions # noqa
- from sqlalchemy.testing import assertions, warnings, profiling # noqa
+ from sqlalchemy.testing import assertions, warnings, profiling # noqa
from sqlalchemy.testing import config # noqa
from sqlalchemy import util # noqa
- warnings.setup_filters()
+ warnings.setup_filters()
def _log(opt_str, value, parser):
global logging
if not logging:
import logging
+
logging.basicConfig()
- if opt_str.endswith('-info'):
+ if opt_str.endswith("-info"):
logging.getLogger(value).setLevel(logging.INFO)
- elif opt_str.endswith('-debug'):
+ elif opt_str.endswith("-debug"):
logging.getLogger(value).setLevel(logging.DEBUG)
def _list_dbs(*args):
print("Available --db options (use --dburi to override)")
- for macro in sorted(file_config.options('db')):
- print("%20s\t%s" % (macro, file_config.get('db', macro)))
+ for macro in sorted(file_config.options("db")):
+ print("%20s\t%s" % (macro, file_config.get("db", macro)))
sys.exit(0)
@@ -207,11 +280,12 @@ def _requirements_opt(opt_str, value, parser):
def _exclude_tag(opt_str, value, parser):
- exclude_tags.add(value.replace('-', '_'))
+ exclude_tags.add(value.replace("-", "_"))
def _include_tag(opt_str, value, parser):
- include_tags.add(value.replace('-', '_'))
+ include_tags.add(value.replace("-", "_"))
+
pre_configure = []
post_configure = []
@@ -243,7 +317,8 @@ def _set_nomemory(opt, file_config):
def _monkeypatch_cdecimal(options, file_config):
if options.cdecimal:
import cdecimal
- sys.modules['decimal'] = cdecimal
+
+ sys.modules["decimal"] = cdecimal
@post
@@ -266,27 +341,28 @@ def _engine_uri(options, file_config):
if options.db:
for db_token in options.db:
- for db in re.split(r'[,\s]+', db_token):
- if db not in file_config.options('db'):
+ for db in re.split(r"[,\s]+", db_token):
+ if db not in file_config.options("db"):
raise RuntimeError(
"Unknown URI specifier '%s'. "
- "Specify --dbs for known uris."
- % db)
+ "Specify --dbs for known uris." % db
+ )
else:
- db_urls.append(file_config.get('db', db))
+ db_urls.append(file_config.get("db", db))
if not db_urls:
- db_urls.append(file_config.get('db', 'default'))
+ db_urls.append(file_config.get("db", "default"))
config._current = None
for db_url in db_urls:
- if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
+ if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
with open(options.write_idents, "a") as file_:
file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")
cfg = provision.setup_config(
- db_url, options, file_config, provision.FOLLOWER_IDENT)
+ db_url, options, file_config, provision.FOLLOWER_IDENT
+ )
if not config._current:
cfg.set_as_current(cfg, testing)
@@ -295,7 +371,7 @@ def _engine_uri(options, file_config):
@post
def _requirements(options, file_config):
- requirement_cls = file_config.get('sqla_testing', "requirement_cls")
+ requirement_cls = file_config.get("sqla_testing", "requirement_cls")
_setup_requirements(requirement_cls)
@@ -334,22 +410,28 @@ def _prep_testing_database(options, file_config):
pass
else:
for vname in view_names:
- e.execute(schema._DropView(
- schema.Table(vname, schema.MetaData())
- ))
+ e.execute(
+ schema._DropView(
+ schema.Table(vname, schema.MetaData())
+ )
+ )
if config.requirements.schemas.enabled_for_config(cfg):
try:
- view_names = inspector.get_view_names(
- schema="test_schema")
+ view_names = inspector.get_view_names(schema="test_schema")
except NotImplementedError:
pass
else:
for vname in view_names:
- e.execute(schema._DropView(
- schema.Table(vname, schema.MetaData(),
- schema="test_schema")
- ))
+ e.execute(
+ schema._DropView(
+ schema.Table(
+ vname,
+ schema.MetaData(),
+ schema="test_schema",
+ )
+ )
+ )
util.drop_all_tables(e, inspector)
@@ -358,23 +440,29 @@ def _prep_testing_database(options, file_config):
if against(cfg, "postgresql"):
from sqlalchemy.dialects import postgresql
+
for enum in inspector.get_enums("*"):
- e.execute(postgresql.DropEnumType(
- postgresql.ENUM(
- name=enum['name'],
- schema=enum['schema'])))
+ e.execute(
+ postgresql.DropEnumType(
+ postgresql.ENUM(
+ name=enum["name"], schema=enum["schema"]
+ )
+ )
+ )
@post
def _reverse_topological(options, file_config):
if options.reversetop:
from sqlalchemy.orm.util import randomize_unitofwork
+
randomize_unitofwork()
@post
def _post_setup_options(opt, file_config):
from sqlalchemy.testing import config
+
config.options = options
config.file_config = file_config
@@ -382,17 +470,20 @@ def _post_setup_options(opt, file_config):
@post
def _setup_profiling(options, file_config):
from sqlalchemy.testing import profiling
+
profiling._profile_stats = profiling.ProfileStatsFile(
- file_config.get('sqla_testing', 'profile_file'))
+ file_config.get("sqla_testing", "profile_file")
+ )
def want_class(cls):
if not issubclass(cls, fixtures.TestBase):
return False
- elif cls.__name__.startswith('_'):
+ elif cls.__name__.startswith("_"):
return False
- elif config.options.backend_only and not getattr(cls, '__backend__',
- False):
+ elif config.options.backend_only and not getattr(
+ cls, "__backend__", False
+ ):
return False
else:
return True
@@ -405,25 +496,28 @@ def want_method(cls, fn):
return False
elif include_tags:
return (
- hasattr(cls, '__tags__') and
- exclusions.tags(cls.__tags__).include_test(
- include_tags, exclude_tags)
+ hasattr(cls, "__tags__")
+ and exclusions.tags(cls.__tags__).include_test(
+ include_tags, exclude_tags
+ )
) or (
- hasattr(fn, '_sa_exclusion_extend') and
- fn._sa_exclusion_extend.include_test(
- include_tags, exclude_tags)
+ hasattr(fn, "_sa_exclusion_extend")
+ and fn._sa_exclusion_extend.include_test(
+ include_tags, exclude_tags
+ )
)
- elif exclude_tags and hasattr(cls, '__tags__'):
+ elif exclude_tags and hasattr(cls, "__tags__"):
return exclusions.tags(cls.__tags__).include_test(
- include_tags, exclude_tags)
- elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'):
+ include_tags, exclude_tags
+ )
+ elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
else:
return True
def generate_sub_tests(cls, module):
- if getattr(cls, '__backend__', False):
+ if getattr(cls, "__backend__", False):
for cfg in _possible_configs_for_cls(cls):
orig_name = cls.__name__
@@ -431,16 +525,13 @@ def generate_sub_tests(cls, module):
# pytest junit plugin, which is tripped up by the brackets
# and periods, so sanitize
- alpha_name = re.sub('[_\[\]\.]+', '_', cfg.name)
- alpha_name = re.sub('_+$', '', alpha_name)
+ alpha_name = re.sub("[_\[\]\.]+", "_", cfg.name)
+ alpha_name = re.sub("_+$", "", alpha_name)
name = "%s_%s" % (cls.__name__, alpha_name)
subcls = type(
name,
- (cls, ),
- {
- "_sa_orig_cls_name": orig_name,
- "__only_on_config__": cfg
- }
+ (cls,),
+ {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
)
setattr(module, name, subcls)
yield subcls
@@ -454,8 +545,8 @@ def start_test_class(cls):
def stop_test_class(cls):
- #from sqlalchemy import inspect
- #assert not inspect(testing.db).get_table_names()
+ # from sqlalchemy import inspect
+ # assert not inspect(testing.db).get_table_names()
engines.testing_reaper._stop_test_ctx()
try:
if not options.low_connections:
@@ -475,7 +566,7 @@ def final_process_cleanup():
def _setup_engine(cls):
- if getattr(cls, '__engine_options__', None):
+ if getattr(cls, "__engine_options__", None):
eng = engines.testing_engine(options=cls.__engine_options__)
config._current.push_engine(eng, testing)
@@ -485,7 +576,7 @@ def before_test(test, test_module_name, test_class, test_name):
# like a nose id, e.g.:
# "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
- name = getattr(test_class, '_sa_orig_cls_name', test_class.__name__)
+ name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
id_ = "%s.%s.%s" % (test_module_name, name, test_name)
@@ -505,16 +596,16 @@ def _possible_configs_for_cls(cls, reasons=None):
if spec(config_obj):
all_configs.remove(config_obj)
- if getattr(cls, '__only_on__', None):
+ if getattr(cls, "__only_on__", None):
spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
for config_obj in list(all_configs):
if not spec(config_obj):
all_configs.remove(config_obj)
- if getattr(cls, '__only_on_config__', None):
+ if getattr(cls, "__only_on_config__", None):
all_configs.intersection_update([cls.__only_on_config__])
- if hasattr(cls, '__requires__'):
+ if hasattr(cls, "__requires__"):
requirements = config.requirements
for config_obj in list(all_configs):
for requirement in cls.__requires__:
@@ -527,7 +618,7 @@ def _possible_configs_for_cls(cls, reasons=None):
reasons.extend(skip_reasons)
break
- if hasattr(cls, '__prefer_requires__'):
+ if hasattr(cls, "__prefer_requires__"):
non_preferred = set()
requirements = config.requirements
for config_obj in list(all_configs):
@@ -546,30 +637,32 @@ def _do_skips(cls):
reasons = []
all_configs = _possible_configs_for_cls(cls, reasons)
- if getattr(cls, '__skip_if__', False):
- for c in getattr(cls, '__skip_if__'):
+ if getattr(cls, "__skip_if__", False):
+ for c in getattr(cls, "__skip_if__"):
if c():
- config.skip_test("'%s' skipped by %s" % (
- cls.__name__, c.__name__)
+ config.skip_test(
+ "'%s' skipped by %s" % (cls.__name__, c.__name__)
)
if not all_configs:
msg = "'%s' unsupported on any DB implementation %s%s" % (
cls.__name__,
", ".join(
- "'%s(%s)+%s'" % (
+ "'%s(%s)+%s'"
+ % (
config_obj.db.name,
".".join(
- str(dig) for dig in
- exclusions._server_version(config_obj.db)),
- config_obj.db.driver
+ str(dig)
+ for dig in exclusions._server_version(config_obj.db)
+ ),
+ config_obj.db.driver,
)
- for config_obj in config.Config.all_configs()
+ for config_obj in config.Config.all_configs()
),
- ", ".join(reasons)
+ ", ".join(reasons),
)
config.skip_test(msg)
- elif hasattr(cls, '__prefer_backends__'):
+ elif hasattr(cls, "__prefer_backends__"):
non_preferred = set()
spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
for config_obj in all_configs:
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index da682ea00..fd0a48462 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -13,6 +13,7 @@ import os
try:
import xdist # noqa
+
has_xdist = True
except ImportError:
has_xdist = False
@@ -24,30 +25,42 @@ def pytest_addoption(parser):
def make_option(name, **kw):
callback_ = kw.pop("callback", None)
if callback_:
+
class CallableAction(argparse.Action):
- def __call__(self, parser, namespace,
- values, option_string=None):
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
callback_(option_string, values, parser)
+
kw["action"] = CallableAction
zeroarg_callback = kw.pop("zeroarg_callback", None)
if zeroarg_callback:
+
class CallableAction(argparse.Action):
- def __init__(self, option_strings,
- dest, default=False,
- required=False, help=None):
- super(CallableAction, self).__init__(
- option_strings=option_strings,
- dest=dest,
- nargs=0,
- const=True,
- default=default,
- required=required,
- help=help)
-
- def __call__(self, parser, namespace,
- values, option_string=None):
+ def __init__(
+ self,
+ option_strings,
+ dest,
+ default=False,
+ required=False,
+ help=None,
+ ):
+ super(CallableAction, self).__init__(
+ option_strings=option_strings,
+ dest=dest,
+ nargs=0,
+ const=True,
+ default=default,
+ required=required,
+ help=help,
+ )
+
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
zeroarg_callback(option_string, values, parser)
+
kw["action"] = CallableAction
group.addoption(name, **kw)
@@ -59,18 +72,18 @@ def pytest_addoption(parser):
def pytest_configure(config):
if hasattr(config, "slaveinput"):
plugin_base.restore_important_follower_config(config.slaveinput)
- plugin_base.configure_follower(
- config.slaveinput["follower_ident"]
- )
+ plugin_base.configure_follower(config.slaveinput["follower_ident"])
else:
- if config.option.write_idents and \
- os.path.exists(config.option.write_idents):
+ if config.option.write_idents and os.path.exists(
+ config.option.write_idents
+ ):
os.remove(config.option.write_idents)
plugin_base.pre_begin(config.option)
- plugin_base.set_coverage_flag(bool(getattr(config.option,
- "cov_source", False)))
+ plugin_base.set_coverage_flag(
+ bool(getattr(config.option, "cov_source", False))
+ )
plugin_base.set_skip_test(pytest.skip.Exception)
@@ -94,10 +107,12 @@ if has_xdist:
node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
from sqlalchemy.testing import provision
+
provision.create_follower_db(node.slaveinput["follower_ident"])
def pytest_testnodedown(node, error):
from sqlalchemy.testing import provision
+
provision.drop_follower_db(node.slaveinput["follower_ident"])
@@ -114,19 +129,22 @@ def pytest_collection_modifyitems(session, config, items):
rebuilt_items = collections.defaultdict(list)
items[:] = [
- item for item in
- items if isinstance(item.parent, pytest.Instance)
- and not item.parent.parent.name.startswith("_")]
+ item
+ for item in items
+ if isinstance(item.parent, pytest.Instance)
+ and not item.parent.parent.name.startswith("_")
+ ]
test_classes = set(item.parent for item in items)
for test_class in test_classes:
for sub_cls in plugin_base.generate_sub_tests(
- test_class.cls, test_class.parent.module):
+ test_class.cls, test_class.parent.module
+ ):
if sub_cls is not test_class.cls:
list_ = rebuilt_items[test_class.cls]
for inst in pytest.Class(
- sub_cls.__name__,
- parent=test_class.parent.parent).collect():
+ sub_cls.__name__, parent=test_class.parent.parent
+ ).collect():
list_.extend(inst.collect())
newitems = []
@@ -139,23 +157,29 @@ def pytest_collection_modifyitems(session, config, items):
# seems like the functions attached to a test class aren't sorted already?
# is that true and why's that? (when using unittest, they're sorted)
- items[:] = sorted(newitems, key=lambda item: (
- item.parent.parent.parent.name,
- item.parent.parent.name,
- item.name
- ))
+ items[:] = sorted(
+ newitems,
+ key=lambda item: (
+ item.parent.parent.parent.name,
+ item.parent.parent.name,
+ item.name,
+ ),
+ )
def pytest_pycollect_makeitem(collector, name, obj):
if inspect.isclass(obj) and plugin_base.want_class(obj):
return pytest.Class(name, parent=collector)
- elif inspect.isfunction(obj) and \
- isinstance(collector, pytest.Instance) and \
- plugin_base.want_method(collector.cls, obj):
+ elif (
+ inspect.isfunction(obj)
+ and isinstance(collector, pytest.Instance)
+ and plugin_base.want_method(collector.cls, obj)
+ ):
return pytest.Function(name, parent=collector)
else:
return []
+
_current_class = None
@@ -180,6 +204,7 @@ def pytest_runtest_setup(item):
global _current_class
class_teardown(item.parent.parent)
_current_class = None
+
item.parent.parent.addfinalizer(finalize)
test_setup(item)
@@ -194,8 +219,9 @@ def pytest_runtest_teardown(item):
def test_setup(item):
- plugin_base.before_test(item, item.parent.module.__name__,
- item.parent.cls, item.name)
+ plugin_base.before_test(
+ item, item.parent.module.__name__, item.parent.cls, item.name
+ )
def test_teardown(item):
diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py
index fab99b186..3986985c7 100644
--- a/lib/sqlalchemy/testing/profiling.py
+++ b/lib/sqlalchemy/testing/profiling.py
@@ -42,17 +42,16 @@ class ProfileStatsFile(object):
def __init__(self, filename):
self.force_write = (
- config.options is not None and
- config.options.force_write_profiles
+ config.options is not None and config.options.force_write_profiles
)
self.write = self.force_write or (
- config.options is not None and
- config.options.write_profiles
+ config.options is not None and config.options.write_profiles
)
self.fname = os.path.abspath(filename)
self.short_fname = os.path.split(self.fname)[-1]
self.data = collections.defaultdict(
- lambda: collections.defaultdict(dict))
+ lambda: collections.defaultdict(dict)
+ )
self._read()
if self.write:
# rewrite for the case where features changed,
@@ -65,7 +64,7 @@ class ProfileStatsFile(object):
dbapi_key = config.db.name + "_" + config.db.driver
# keep it at 2.7, 3.1, 3.2, etc. for now.
- py_version = '.'.join([str(v) for v in sys.version_info[0:2]])
+ py_version = ".".join([str(v) for v in sys.version_info[0:2]])
platform_tokens = [py_version]
platform_tokens.append(dbapi_key)
@@ -87,8 +86,7 @@ class ProfileStatsFile(object):
def has_stats(self):
test_key = _current_test
return (
- test_key in self.data and
- self.platform_key in self.data[test_key]
+ test_key in self.data and self.platform_key in self.data[test_key]
)
def result(self, callcount):
@@ -96,15 +94,15 @@ class ProfileStatsFile(object):
per_fn = self.data[test_key]
per_platform = per_fn[self.platform_key]
- if 'counts' not in per_platform:
- per_platform['counts'] = counts = []
+ if "counts" not in per_platform:
+ per_platform["counts"] = counts = []
else:
- counts = per_platform['counts']
+ counts = per_platform["counts"]
- if 'current_count' not in per_platform:
- per_platform['current_count'] = current_count = 0
+ if "current_count" not in per_platform:
+ per_platform["current_count"] = current_count = 0
else:
- current_count = per_platform['current_count']
+ current_count = per_platform["current_count"]
has_count = len(counts) > current_count
@@ -114,16 +112,16 @@ class ProfileStatsFile(object):
self._write()
result = None
else:
- result = per_platform['lineno'], counts[current_count]
- per_platform['current_count'] += 1
+ result = per_platform["lineno"], counts[current_count]
+ per_platform["current_count"] += 1
return result
def replace(self, callcount):
test_key = _current_test
per_fn = self.data[test_key]
per_platform = per_fn[self.platform_key]
- counts = per_platform['counts']
- current_count = per_platform['current_count']
+ counts = per_platform["counts"]
+ current_count = per_platform["current_count"]
if current_count < len(counts):
counts[current_count - 1] = callcount
else:
@@ -164,9 +162,9 @@ class ProfileStatsFile(object):
per_fn = self.data[test_key]
per_platform = per_fn[platform_key]
c = [int(count) for count in counts.split(",")]
- per_platform['counts'] = c
- per_platform['lineno'] = lineno + 1
- per_platform['current_count'] = 0
+ per_platform["counts"] = c
+ per_platform["lineno"] = lineno + 1
+ per_platform["current_count"] = 0
profile_f.close()
def _write(self):
@@ -179,7 +177,7 @@ class ProfileStatsFile(object):
profile_f.write("\n# TEST: %s\n\n" % test_key)
for platform_key in sorted(per_fn):
per_platform = per_fn[platform_key]
- c = ",".join(str(count) for count in per_platform['counts'])
+ c = ",".join(str(count) for count in per_platform["counts"])
profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
profile_f.close()
@@ -199,7 +197,9 @@ def function_call_count(variance=0.05):
def wrap(*args, **kw):
with count_functions(variance=variance):
return fn(*args, **kw)
+
return update_wrapper(wrap, fn)
+
return decorate
@@ -213,21 +213,22 @@ def count_functions(variance=0.05):
"No profiling stats available on this "
"platform for this function. Run tests with "
"--write-profiles to add statistics to %s for "
- "this platform." % _profile_stats.short_fname)
+ "this platform." % _profile_stats.short_fname
+ )
gc_collect()
pr = cProfile.Profile()
pr.enable()
- #began = time.time()
+ # began = time.time()
yield
- #ended = time.time()
+ # ended = time.time()
pr.disable()
- #s = compat.StringIO()
+ # s = compat.StringIO()
stats = pstats.Stats(pr, stream=sys.stdout)
- #timespent = ended - began
+ # timespent = ended - began
callcount = stats.total_calls
expected = _profile_stats.result(callcount)
@@ -237,11 +238,7 @@ def count_functions(variance=0.05):
else:
line_no, expected_count = expected
- print(("Pstats calls: %d Expected %s" % (
- callcount,
- expected_count
- )
- ))
+ print(("Pstats calls: %d Expected %s" % (callcount, expected_count)))
stats.sort_stats("cumulative")
stats.print_stats()
@@ -259,7 +256,9 @@ def count_functions(variance=0.05):
"--write-profiles to "
"regenerate this callcount."
% (
- callcount, (variance * 100),
- expected_count, _profile_stats.platform_key))
-
-
+ callcount,
+ (variance * 100),
+ expected_count,
+ _profile_stats.platform_key,
+ )
+ )
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index c0ca7c1cb..25028ccb3 100644
--- a/lib/sqlalchemy/testing/provision.py
+++ b/lib/sqlalchemy/testing/provision.py
@@ -8,6 +8,7 @@ import collections
import os
import time
import logging
+
log = logging.getLogger(__name__)
FOLLOWER_IDENT = None
@@ -25,6 +26,7 @@ class register(object):
def decorate(fn):
self.fns[dbname] = fn
return self
+
return decorate
def __call__(self, cfg, *arg):
@@ -38,7 +40,7 @@ class register(object):
if backend in self.fns:
return self.fns[backend](cfg, *arg)
else:
- return self.fns['*'](cfg, *arg)
+ return self.fns["*"](cfg, *arg)
def create_follower_db(follower_ident):
@@ -82,9 +84,7 @@ def _configs_for_db_operation():
for cfg in config.Config.all_configs():
url = cfg.db.url
backend = url.get_backend_name()
- host_conf = (
- backend,
- url.username, url.host, url.database)
+ host_conf = (backend, url.username, url.host, url.database)
if host_conf not in hosts:
yield cfg
@@ -128,14 +128,13 @@ def _follower_url_from_main(url, ident):
@_update_db_opts.for_db("mssql")
def _mssql_update_db_opts(db_url, db_opts):
- db_opts['legacy_schema_aliasing'] = False
-
+ db_opts["legacy_schema_aliasing"] = False
@_follower_url_from_main.for_db("sqlite")
def _sqlite_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
- if not url.database or url.database == ':memory:':
+ if not url.database or url.database == ":memory:":
return url
else:
return sa_url.make_url("sqlite:///%s.db" % ident)
@@ -151,19 +150,20 @@ def _sqlite_post_configure_engine(url, engine, follower_ident):
# as an attached
if not follower_ident:
dbapi_connection.execute(
- 'ATTACH DATABASE "test_schema.db" AS test_schema')
+ 'ATTACH DATABASE "test_schema.db" AS test_schema'
+ )
else:
dbapi_connection.execute(
'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
- % follower_ident)
+ % follower_ident
+ )
@_create_db.for_db("postgresql")
def _pg_create_db(cfg, eng, ident):
template_db = cfg.options.postgresql_templatedb
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
try:
_pg_drop_db(cfg, conn, ident)
except Exception:
@@ -175,7 +175,8 @@ def _pg_create_db(cfg, eng, ident):
while True:
try:
conn.execute(
- "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db))
+ "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)
+ )
except exc.OperationalError as err:
attempt += 1
if attempt >= 3:
@@ -184,8 +185,11 @@ def _pg_create_db(cfg, eng, ident):
log.info(
"Waiting to create %s, URI %r, "
"template DB %s is in use sleeping for .5",
- ident, eng.url, template_db)
- time.sleep(.5)
+ ident,
+ eng.url,
+ template_db,
+ )
+ time.sleep(0.5)
else:
break
@@ -203,9 +207,11 @@ def _mysql_create_db(cfg, eng, ident):
# 1271, u"Illegal mix of collations for operation 'UNION'"
conn.execute("CREATE DATABASE %s CHARACTER SET utf8mb3" % ident)
conn.execute(
- "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident)
+ "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident
+ )
conn.execute(
- "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident)
+ "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident
+ )
@_configure_follower.for_db("mysql")
@@ -221,14 +227,15 @@ def _sqlite_create_db(cfg, eng, ident):
@_drop_db.for_db("postgresql")
def _pg_drop_db(cfg, eng, ident):
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
conn.execute(
text(
"select pg_terminate_backend(pid) from pg_stat_activity "
"where usename=current_user and pid != pg_backend_pid() "
"and datname=:dname"
- ), dname=ident)
+ ),
+ dname=ident,
+ )
conn.execute("DROP DATABASE %s" % ident)
@@ -257,11 +264,12 @@ def _oracle_create_db(cfg, eng, ident):
conn.execute("create user %s identified by xe" % ident)
conn.execute("create user %s_ts1 identified by xe" % ident)
conn.execute("create user %s_ts2 identified by xe" % ident)
- conn.execute("grant dba to %s" % (ident, ))
+ conn.execute("grant dba to %s" % (ident,))
conn.execute("grant unlimited tablespace to %s" % ident)
conn.execute("grant unlimited tablespace to %s_ts1" % ident)
conn.execute("grant unlimited tablespace to %s_ts2" % ident)
+
@_configure_follower.for_db("oracle")
def _oracle_configure_follower(config, ident):
config.test_schema = "%s_ts1" % ident
@@ -320,6 +328,7 @@ def reap_dbs(idents_file):
elif backend == "mssql":
_reap_mssql_dbs(url, ident)
+
def _reap_oracle_dbs(url, idents):
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
@@ -330,8 +339,9 @@ def _reap_oracle_dbs(url, idents):
to_reap = conn.execute(
"select u.username from all_users u where username "
"like 'TEST_%' and not exists (select username "
- "from v$session where username=u.username)")
- all_names = {username.lower() for (username, ) in to_reap}
+ "from v$session where username=u.username)"
+ )
+ all_names = {username.lower() for (username,) in to_reap}
to_drop = set()
for name in all_names:
if name.endswith("_ts1") or name.endswith("_ts2"):
@@ -348,28 +358,28 @@ def _reap_oracle_dbs(url, idents):
if _ora_drop_ignore(conn, username):
dropped += 1
log.info(
- "Dropped %d out of %d stale databases detected",
- dropped, total)
-
+ "Dropped %d out of %d stale databases detected", dropped, total
+ )
@_follower_url_from_main.for_db("oracle")
def _oracle_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
url.username = ident
- url.password = 'xe'
+ url.password = "xe"
return url
@_create_db.for_db("mssql")
def _mssql_create_db(cfg, eng, ident):
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
conn.execute("create database %s" % ident)
conn.execute(
- "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident)
+ "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident
+ )
conn.execute(
- "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident)
+ "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident
+ )
conn.execute("use %s" % ident)
conn.execute("create schema test_schema")
conn.execute("create schema test_schema_2")
@@ -377,10 +387,10 @@ def _mssql_create_db(cfg, eng, ident):
@_drop_db.for_db("mssql")
def _mssql_drop_db(cfg, eng, ident):
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
_mssql_drop_ignore(conn, ident)
+
def _mssql_drop_ignore(conn, ident):
try:
# typically when this happens, we can't KILL the session anyway,
@@ -401,8 +411,7 @@ def _mssql_drop_ignore(conn, ident):
def _reap_mssql_dbs(url, idents):
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
- with eng.connect().execution_options(
- isolation_level="AUTOCOMMIT") as conn:
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
log.info("identifiers in file: %s", ", ".join(idents))
@@ -410,8 +419,9 @@ def _reap_mssql_dbs(url, idents):
"select d.name from sys.databases as d where name "
"like 'TEST_%' and not exists (select session_id "
"from sys.dm_exec_sessions "
- "where database_id=d.database_id)")
- all_names = {dbname.lower() for (dbname, ) in to_reap}
+ "where database_id=d.database_id)"
+ )
+ all_names = {dbname.lower() for (dbname,) in to_reap}
to_drop = set()
for name in all_names:
if name in idents:
@@ -422,5 +432,5 @@ def _reap_mssql_dbs(url, idents):
if _mssql_drop_ignore(conn, dbname):
dropped += 1
log.info(
- "Dropped %d out of %d stale databases detected",
- dropped, total)
+ "Dropped %d out of %d stale databases detected", dropped, total
+ )
diff --git a/lib/sqlalchemy/testing/replay_fixture.py b/lib/sqlalchemy/testing/replay_fixture.py
index b50f52e3d..9832b07a2 100644
--- a/lib/sqlalchemy/testing/replay_fixture.py
+++ b/lib/sqlalchemy/testing/replay_fixture.py
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
class ReplayFixtureTest(fixtures.TestBase):
-
@contextlib.contextmanager
def _dummy_ctx(self, *arg, **kw):
yield
@@ -22,8 +21,8 @@ class ReplayFixtureTest(fixtures.TestBase):
creator = config.db.pool._creator
recorder = lambda: dbapi_session.recorder(creator())
engine = create_engine(
- config.db.url, creator=recorder,
- use_native_hstore=False)
+ config.db.url, creator=recorder, use_native_hstore=False
+ )
self.metadata = MetaData(engine)
self.engine = engine
self.session = Session(engine)
@@ -37,8 +36,8 @@ class ReplayFixtureTest(fixtures.TestBase):
player = lambda: dbapi_session.player()
engine = create_engine(
- config.db.url, creator=player,
- use_native_hstore=False)
+ config.db.url, creator=player, use_native_hstore=False
+ )
self.metadata = MetaData(engine)
self.engine = engine
@@ -74,21 +73,49 @@ class ReplayableSession(object):
NoAttribute = object()
if util.py2k:
- Natives = set([getattr(types, t)
- for t in dir(types) if not t.startswith('_')]).\
- difference([getattr(types, t)
- for t in ('FunctionType', 'BuiltinFunctionType',
- 'MethodType', 'BuiltinMethodType',
- 'LambdaType', 'UnboundMethodType',)])
+ Natives = set(
+ [getattr(types, t) for t in dir(types) if not t.startswith("_")]
+ ).difference(
+ [
+ getattr(types, t)
+ for t in (
+ "FunctionType",
+ "BuiltinFunctionType",
+ "MethodType",
+ "BuiltinMethodType",
+ "LambdaType",
+ "UnboundMethodType",
+ )
+ ]
+ )
else:
- Natives = set([getattr(types, t)
- for t in dir(types) if not t.startswith('_')]).\
- union([type(t) if not isinstance(t, type)
- else t for t in __builtins__.values()]).\
- difference([getattr(types, t)
- for t in ('FunctionType', 'BuiltinFunctionType',
- 'MethodType', 'BuiltinMethodType',
- 'LambdaType', )])
+ Natives = (
+ set(
+ [
+ getattr(types, t)
+ for t in dir(types)
+ if not t.startswith("_")
+ ]
+ )
+ .union(
+ [
+ type(t) if not isinstance(t, type) else t
+ for t in __builtins__.values()
+ ]
+ )
+ .difference(
+ [
+ getattr(types, t)
+ for t in (
+ "FunctionType",
+ "BuiltinFunctionType",
+ "MethodType",
+ "BuiltinMethodType",
+ "LambdaType",
+ )
+ ]
+ )
+ )
def __init__(self):
self.buffer = deque()
@@ -105,8 +132,10 @@ class ReplayableSession(object):
self._subject = subject
def __call__(self, *args, **kw):
- subject, buffer = [object.__getattribute__(self, x)
- for x in ('_subject', '_buffer')]
+ subject, buffer = [
+ object.__getattribute__(self, x)
+ for x in ("_subject", "_buffer")
+ ]
result = subject(*args, **kw)
if type(result) not in ReplayableSession.Natives:
@@ -126,8 +155,10 @@ class ReplayableSession(object):
except AttributeError:
pass
- subject, buffer = [object.__getattribute__(self, x)
- for x in ('_subject', '_buffer')]
+ subject, buffer = [
+ object.__getattribute__(self, x)
+ for x in ("_subject", "_buffer")
+ ]
try:
result = type(subject).__getattribute__(subject, key)
except AttributeError:
@@ -146,7 +177,7 @@ class ReplayableSession(object):
self._buffer = buffer
def __call__(self, *args, **kw):
- buffer = object.__getattribute__(self, '_buffer')
+ buffer = object.__getattribute__(self, "_buffer")
result = buffer.popleft()
if result is ReplayableSession.Callable:
return self
@@ -162,7 +193,7 @@ class ReplayableSession(object):
return object.__getattribute__(self, key)
except AttributeError:
pass
- buffer = object.__getattribute__(self, '_buffer')
+ buffer = object.__getattribute__(self, "_buffer")
result = buffer.popleft()
if result is ReplayableSession.Callable:
return self
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index 58df643f4..c96d26d32 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -26,7 +26,6 @@ class Requirements(object):
class SuiteRequirements(Requirements):
-
@property
def create_table(self):
"""target platform can emit basic CreateTable DDL."""
@@ -68,8 +67,8 @@ class SuiteRequirements(Requirements):
# somehow only_if([x, y]) isn't working here, negation/conjunctions
# getting confused.
return exclusions.only_if(
- lambda: self.on_update_cascade.enabled or
- self.deferrable_fks.enabled
+ lambda: self.on_update_cascade.enabled
+ or self.deferrable_fks.enabled
)
@property
@@ -231,22 +230,21 @@ class SuiteRequirements(Requirements):
def sane_rowcount(self):
return exclusions.skip_if(
lambda config: not config.db.dialect.supports_sane_rowcount,
- "driver doesn't support 'sane' rowcount"
+ "driver doesn't support 'sane' rowcount",
)
@property
def sane_multi_rowcount(self):
return exclusions.fails_if(
lambda config: not config.db.dialect.supports_sane_multi_rowcount,
- "driver %(driver)s %(doesnt_support)s 'sane' multi row count"
+ "driver %(driver)s %(doesnt_support)s 'sane' multi row count",
)
@property
def sane_rowcount_w_returning(self):
return exclusions.fails_if(
- lambda config:
- not config.db.dialect.supports_sane_rowcount_returning,
- "driver doesn't support 'sane' rowcount when returning is on"
+ lambda config: not config.db.dialect.supports_sane_rowcount_returning,
+ "driver doesn't support 'sane' rowcount when returning is on",
)
@property
@@ -255,9 +253,9 @@ class SuiteRequirements(Requirements):
INSERT DEFAULT VALUES or equivalent."""
return exclusions.only_if(
- lambda config: config.db.dialect.supports_empty_insert or
- config.db.dialect.supports_default_values,
- "empty inserts not supported"
+ lambda config: config.db.dialect.supports_empty_insert
+ or config.db.dialect.supports_default_values,
+ "empty inserts not supported",
)
@property
@@ -272,7 +270,7 @@ class SuiteRequirements(Requirements):
return exclusions.only_if(
lambda config: config.db.dialect.implicit_returning,
- "%(database)s %(does_support)s 'returning'"
+ "%(database)s %(does_support)s 'returning'",
)
@property
@@ -297,7 +295,7 @@ class SuiteRequirements(Requirements):
return exclusions.skip_if(
lambda config: not config.db.dialect.requires_name_normalize,
- "Backend does not require denormalized names."
+ "Backend does not require denormalized names.",
)
@property
@@ -307,7 +305,7 @@ class SuiteRequirements(Requirements):
return exclusions.skip_if(
lambda config: not config.db.dialect.supports_multivalues_insert,
- "Backend does not support multirow inserts."
+ "Backend does not support multirow inserts.",
)
@property
@@ -355,27 +353,32 @@ class SuiteRequirements(Requirements):
def server_side_cursors(self):
"""Target dialect must support server side cursors."""
- return exclusions.only_if([
- lambda config: config.db.dialect.supports_server_side_cursors
- ], "no server side cursors support")
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_server_side_cursors],
+ "no server side cursors support",
+ )
@property
def sequences(self):
"""Target database must support SEQUENCEs."""
- return exclusions.only_if([
- lambda config: config.db.dialect.supports_sequences
- ], "no sequence support")
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_sequences],
+ "no sequence support",
+ )
@property
def sequences_optional(self):
"""Target database supports sequences, but also optionally
as a means of generating new PK values."""
- return exclusions.only_if([
- lambda config: config.db.dialect.supports_sequences and
- config.db.dialect.sequences_optional
- ], "no sequence support, or sequences not optional")
+ return exclusions.only_if(
+ [
+ lambda config: config.db.dialect.supports_sequences
+ and config.db.dialect.sequences_optional
+ ],
+ "no sequence support, or sequences not optional",
+ )
@property
def reflects_pk_names(self):
@@ -841,7 +844,8 @@ class SuiteRequirements(Requirements):
"""
return exclusions.skip_if(
- lambda config: config.options.low_connections)
+ lambda config: config.options.low_connections
+ )
@property
def timing_intensive(self):
@@ -859,37 +863,37 @@ class SuiteRequirements(Requirements):
"""
return exclusions.skip_if(
lambda config: util.py3k and config.options.has_coverage,
- "Stability issues with coverage + py3k"
+ "Stability issues with coverage + py3k",
)
@property
def python2(self):
return exclusions.skip_if(
lambda: sys.version_info >= (3,),
- "Python version 2.xx is required."
+ "Python version 2.xx is required.",
)
@property
def python3(self):
return exclusions.skip_if(
- lambda: sys.version_info < (3,),
- "Python version 3.xx is required."
+ lambda: sys.version_info < (3,), "Python version 3.xx is required."
)
@property
def cpython(self):
return exclusions.only_if(
- lambda: util.cpython,
- "cPython interpreter needed"
+ lambda: util.cpython, "cPython interpreter needed"
)
@property
def non_broken_pickle(self):
from sqlalchemy.util import pickle
+
return exclusions.only_if(
- lambda: not util.pypy and pickle.__name__ == 'cPickle'
- or sys.version_info >= (3, 2),
- "Needs cPickle+cPython or newer Python 3 pickle"
+ lambda: not util.pypy
+ and pickle.__name__ == "cPickle"
+ or sys.version_info >= (3, 2),
+ "Needs cPickle+cPython or newer Python 3 pickle",
)
@property
@@ -910,7 +914,7 @@ class SuiteRequirements(Requirements):
"""
return exclusions.skip_if(
lambda config: config.options.has_coverage,
- "Issues observed when coverage is enabled"
+ "Issues observed when coverage is enabled",
)
def _has_mysql_on_windows(self, config):
@@ -931,8 +935,9 @@ class SuiteRequirements(Requirements):
def _has_sqlite(self):
from sqlalchemy import create_engine
+
try:
- create_engine('sqlite://')
+ create_engine("sqlite://")
return True
except ImportError:
return False
@@ -940,6 +945,7 @@ class SuiteRequirements(Requirements):
def _has_cextensions(self):
try:
from sqlalchemy import cresultproxy, cprocessors
+
return True
except ImportError:
return False
diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py
index 87be0749c..6aa820fd5 100644
--- a/lib/sqlalchemy/testing/runner.py
+++ b/lib/sqlalchemy/testing/runner.py
@@ -47,4 +47,4 @@ def setup_py_test():
to nose.
"""
- nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner'])
+ nose.main(addplugins=[NoseSQLAlchemy()], argv=["runner"])
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py
index 401c8cbb7..b345a9487 100644
--- a/lib/sqlalchemy/testing/schema.py
+++ b/lib/sqlalchemy/testing/schema.py
@@ -9,7 +9,7 @@ from . import exclusions
from .. import schema, event
from . import config
-__all__ = 'Table', 'Column',
+__all__ = "Table", "Column"
table_options = {}
@@ -17,30 +17,35 @@ table_options = {}
def Table(*args, **kw):
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
- test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')}
+ test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
kw.update(table_options)
- if exclusions.against(config._current, 'mysql'):
- if 'mysql_engine' not in kw and 'mysql_type' not in kw and \
- "autoload_with" not in kw:
- if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
- kw['mysql_engine'] = 'InnoDB'
+ if exclusions.against(config._current, "mysql"):
+ if (
+ "mysql_engine" not in kw
+ and "mysql_type" not in kw
+ and "autoload_with" not in kw
+ ):
+ if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
+ kw["mysql_engine"] = "InnoDB"
else:
- kw['mysql_engine'] = 'MyISAM'
+ kw["mysql_engine"] = "MyISAM"
# Apply some default cascading rules for self-referential foreign keys.
# MySQL InnoDB has some issues around seleting self-refs too.
- if exclusions.against(config._current, 'firebird'):
+ if exclusions.against(config._current, "firebird"):
table_name = args[0]
- unpack = (config.db.dialect.
- identifier_preparer.unformat_identifiers)
+ unpack = config.db.dialect.identifier_preparer.unformat_identifiers
# Only going after ForeignKeys in Columns. May need to
# expand to ForeignKeyConstraint too.
- fks = [fk
- for col in args if isinstance(col, schema.Column)
- for fk in col.foreign_keys]
+ fks = [
+ fk
+ for col in args
+ if isinstance(col, schema.Column)
+ for fk in col.foreign_keys
+ ]
for fk in fks:
# root around in raw spec
@@ -54,9 +59,9 @@ def Table(*args, **kw):
name = unpack(ref)[0]
if name == table_name:
if fk.ondelete is None:
- fk.ondelete = 'CASCADE'
+ fk.ondelete = "CASCADE"
if fk.onupdate is None:
- fk.onupdate = 'CASCADE'
+ fk.onupdate = "CASCADE"
return schema.Table(*args, **kw)
@@ -64,37 +69,46 @@ def Table(*args, **kw):
def Column(*args, **kw):
"""A schema.Column wrapper/hook for dialect-specific tweaks."""
- test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')}
+ test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
if not config.requirements.foreign_key_ddl.enabled_for_config(config):
args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
col = schema.Column(*args, **kw)
- if test_opts.get('test_needs_autoincrement', False) and \
- kw.get('primary_key', False):
+ if test_opts.get("test_needs_autoincrement", False) and kw.get(
+ "primary_key", False
+ ):
if col.default is None and col.server_default is None:
col.autoincrement = True
# allow any test suite to pick up on this
- col.info['test_needs_autoincrement'] = True
+ col.info["test_needs_autoincrement"] = True
# hardcoded rule for firebird, oracle; this should
# be moved out
- if exclusions.against(config._current, 'firebird', 'oracle'):
+ if exclusions.against(config._current, "firebird", "oracle"):
+
def add_seq(c, tbl):
c._init_items(
- schema.Sequence(_truncate_name(
- config.db.dialect, tbl.name + '_' + c.name + '_seq'),
- optional=True)
+ schema.Sequence(
+ _truncate_name(
+ config.db.dialect, tbl.name + "_" + c.name + "_seq"
+ ),
+ optional=True,
+ )
)
- event.listen(col, 'after_parent_attach', add_seq, propagate=True)
+
+ event.listen(col, "after_parent_attach", add_seq, propagate=True)
return col
def _truncate_name(dialect, name):
if len(name) > dialect.max_identifier_length:
- return name[0:max(dialect.max_identifier_length - 6, 0)] + \
- "_" + hex(hash(name) % 64)[2:]
+ return (
+ name[0 : max(dialect.max_identifier_length - 6, 0)]
+ + "_"
+ + hex(hash(name) % 64)[2:]
+ )
else:
return name
diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py
index 748d9722d..a4e142c5a 100644
--- a/lib/sqlalchemy/testing/suite/__init__.py
+++ b/lib/sqlalchemy/testing/suite/__init__.py
@@ -1,4 +1,3 @@
-
from sqlalchemy.testing.suite.test_cte import *
from sqlalchemy.testing.suite.test_dialect import *
from sqlalchemy.testing.suite.test_ddl import *
diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py
index cc72278e6..d2f35933b 100644
--- a/lib/sqlalchemy/testing/suite/test_cte.py
+++ b/lib/sqlalchemy/testing/suite/test_cte.py
@@ -10,22 +10,28 @@ from ..schema import Table, Column
class CTETest(fixtures.TablesTest):
__backend__ = True
- __requires__ = 'ctes',
+ __requires__ = ("ctes",)
- run_inserts = 'each'
- run_deletes = 'each'
+ run_inserts = "each"
+ run_deletes = "each"
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)),
- Column("parent_id", ForeignKey("some_table.id")))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", ForeignKey("some_table.id")),
+ )
- Table("some_other_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)),
- Column("parent_id", Integer))
+ Table(
+ "some_other_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", Integer),
+ )
@classmethod
def insert_data(cls):
@@ -36,28 +42,33 @@ class CTETest(fixtures.TablesTest):
{"id": 2, "data": "d2", "parent_id": 1},
{"id": 3, "data": "d3", "parent_id": 1},
{"id": 4, "data": "d4", "parent_id": 3},
- {"id": 5, "data": "d5", "parent_id": 3}
- ]
+ {"id": 5, "data": "d5", "parent_id": 3},
+ ],
)
def test_select_nonrecursive_round_trip(self):
some_table = self.tables.some_table
with config.db.connect() as conn:
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
result = conn.execute(
select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"]))
)
- eq_(result.fetchall(), [("d4", )])
+ eq_(result.fetchall(), [("d4",)])
def test_select_recursive_round_trip(self):
some_table = self.tables.some_table
with config.db.connect() as conn:
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])).cte(
- "some_cte", recursive=True)
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte", recursive=True)
+ )
cte_alias = cte.alias("c1")
st1 = some_table.alias()
@@ -67,12 +78,13 @@ class CTETest(fixtures.TablesTest):
select([st1]).where(st1.c.id == cte_alias.c.parent_id)
)
result = conn.execute(
- select([cte.c.data]).where(
- cte.c.data != "d2").order_by(cte.c.data.desc())
+ select([cte.c.data])
+ .where(cte.c.data != "d2")
+ .order_by(cte.c.data.desc())
)
eq_(
result.fetchall(),
- [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)]
+ [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
)
def test_insert_from_select_round_trip(self):
@@ -80,20 +92,21 @@ class CTETest(fixtures.TablesTest):
some_other_table = self.tables.some_other_table
with config.db.connect() as conn:
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
some_other_table.insert().from_select(
- ["id", "data", "parent_id"],
- select([cte])
+ ["id", "data", "parent_id"], select([cte])
)
)
eq_(
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
- [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)]
+ [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
)
@testing.requires.ctes_with_update_delete
@@ -105,27 +118,31 @@ class CTETest(fixtures.TablesTest):
with config.db.connect() as conn:
conn.execute(
some_other_table.insert().from_select(
- ['id', 'data', 'parent_id'],
- select([some_table])
+ ["id", "data", "parent_id"], select([some_table])
)
)
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
- some_other_table.update().values(parent_id=5).where(
- some_other_table.c.data == cte.c.data
- )
+ some_other_table.update()
+ .values(parent_id=5)
+ .where(some_other_table.c.data == cte.c.data)
)
eq_(
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
[
- (1, "d1", None), (2, "d2", 5),
- (3, "d3", 5), (4, "d4", 5), (5, "d5", 3)
- ]
+ (1, "d1", None),
+ (2, "d2", 5),
+ (3, "d3", 5),
+ (4, "d4", 5),
+ (5, "d5", 3),
+ ],
)
@testing.requires.ctes_with_update_delete
@@ -137,14 +154,15 @@ class CTETest(fixtures.TablesTest):
with config.db.connect() as conn:
conn.execute(
some_other_table.insert().from_select(
- ['id', 'data', 'parent_id'],
- select([some_table])
+ ["id", "data", "parent_id"], select([some_table])
)
)
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
some_other_table.delete().where(
some_other_table.c.data == cte.c.data
@@ -154,9 +172,7 @@ class CTETest(fixtures.TablesTest):
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
- [
- (1, "d1", None), (5, "d5", 3)
- ]
+ [(1, "d1", None), (5, "d5", 3)],
)
@testing.requires.ctes_with_update_delete
@@ -168,26 +184,26 @@ class CTETest(fixtures.TablesTest):
with config.db.connect() as conn:
conn.execute(
some_other_table.insert().from_select(
- ['id', 'data', 'parent_id'],
- select([some_table])
+ ["id", "data", "parent_id"], select([some_table])
)
)
- cte = select([some_table]).where(
- some_table.c.data.in_(["d2", "d3", "d4"])
- ).cte("some_cte")
+ cte = (
+ select([some_table])
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
conn.execute(
some_other_table.delete().where(
- some_other_table.c.data ==
- select([cte.c.data]).where(
- cte.c.id == some_other_table.c.id)
+ some_other_table.c.data
+ == select([cte.c.data]).where(
+ cte.c.id == some_other_table.c.id
+ )
)
)
eq_(
conn.execute(
select([some_other_table]).order_by(some_other_table.c.id)
).fetchall(),
- [
- (1, "d1", None), (5, "d5", 3)
- ]
+ [(1, "d1", None), (5, "d5", 3)],
)
diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py
index 1d8010c8a..7c44388d4 100644
--- a/lib/sqlalchemy/testing/suite/test_ddl.py
+++ b/lib/sqlalchemy/testing/suite/test_ddl.py
@@ -1,5 +1,3 @@
-
-
from .. import fixtures, config, util
from ..config import requirements
from ..assertions import eq_
@@ -11,55 +9,47 @@ class TableDDLTest(fixtures.TestBase):
__backend__ = True
def _simple_fixture(self):
- return Table('test_table', self.metadata,
- Column('id', Integer, primary_key=True,
- autoincrement=False),
- Column('data', String(50))
- )
+ return Table(
+ "test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
def _underscore_fixture(self):
- return Table('_test_table', self.metadata,
- Column('id', Integer, primary_key=True,
- autoincrement=False),
- Column('_data', String(50))
- )
+ return Table(
+ "_test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("_data", String(50)),
+ )
def _simple_roundtrip(self, table):
with config.db.begin() as conn:
- conn.execute(table.insert().values((1, 'some data')))
+ conn.execute(table.insert().values((1, "some data")))
result = conn.execute(table.select())
- eq_(
- result.first(),
- (1, 'some data')
- )
+ eq_(result.first(), (1, "some data"))
@requirements.create_table
@util.provide_metadata
def test_create_table(self):
table = self._simple_fixture()
- table.create(
- config.db, checkfirst=False
- )
+ table.create(config.db, checkfirst=False)
self._simple_roundtrip(table)
@requirements.drop_table
@util.provide_metadata
def test_drop_table(self):
table = self._simple_fixture()
- table.create(
- config.db, checkfirst=False
- )
- table.drop(
- config.db, checkfirst=False
- )
+ table.create(config.db, checkfirst=False)
+ table.drop(config.db, checkfirst=False)
@requirements.create_table
@util.provide_metadata
def test_underscore_names(self):
table = self._underscore_fixture()
- table.create(
- config.db, checkfirst=False
- )
+ table.create(config.db, checkfirst=False)
self._simple_roundtrip(table)
-__all__ = ('TableDDLTest', )
+
+__all__ = ("TableDDLTest",)
diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py
index 2c5dd0e36..5e589f3b8 100644
--- a/lib/sqlalchemy/testing/suite/test_dialect.py
+++ b/lib/sqlalchemy/testing/suite/test_dialect.py
@@ -15,16 +15,19 @@ class ExceptionTest(fixtures.TablesTest):
specific exceptions from real round trips, we need to be conservative.
"""
- run_deletes = 'each'
+
+ run_deletes = "each"
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('manual_pk', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50))
- )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
@requirements.duplicate_key_raises_integrity_error
def test_integrity_error(self):
@@ -33,15 +36,14 @@ class ExceptionTest(fixtures.TablesTest):
trans = conn.begin()
conn.execute(
- self.tables.manual_pk.insert(),
- {'id': 1, 'data': 'd1'}
+ self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
)
assert_raises(
exc.IntegrityError,
conn.execute,
self.tables.manual_pk.insert(),
- {'id': 1, 'data': 'd1'}
+ {"id": 1, "data": "d1"},
)
trans.rollback()
@@ -49,38 +51,39 @@ class ExceptionTest(fixtures.TablesTest):
class AutocommitTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
- __requires__ = 'autocommit',
+ __requires__ = ("autocommit",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('some_table', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50)),
- test_needs_acid=True
- )
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ test_needs_acid=True,
+ )
def _test_conn_autocommits(self, conn, autocommit):
trans = conn.begin()
conn.execute(
- self.tables.some_table.insert(),
- {"id": 1, "data": "some data"}
+ self.tables.some_table.insert(), {"id": 1, "data": "some data"}
)
trans.rollback()
eq_(
conn.scalar(select([self.tables.some_table.c.id])),
- 1 if autocommit else None
+ 1 if autocommit else None,
)
conn.execute(self.tables.some_table.delete())
def test_autocommit_on(self):
conn = config.db.connect()
- c2 = conn.execution_options(isolation_level='AUTOCOMMIT')
+ c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
self._test_conn_autocommits(c2, True)
conn.invalidate()
self._test_conn_autocommits(conn, False)
@@ -98,7 +101,7 @@ class EscapingTest(fixtures.TestBase):
"""
m = self.metadata
- t = Table('t', m, Column('data', String(50)))
+ t = Table("t", m, Column("data", String(50)))
t.create(config.db)
with config.db.begin() as conn:
conn.execute(t.insert(), dict(data="some % value"))
@@ -107,14 +110,17 @@ class EscapingTest(fixtures.TestBase):
eq_(
conn.scalar(
select([t.c.data]).where(
- t.c.data == literal_column("'some % value'"))
+ t.c.data == literal_column("'some % value'")
+ )
),
- "some % value"
+ "some % value",
)
eq_(
conn.scalar(
select([t.c.data]).where(
- t.c.data == literal_column("'some %% other value'"))
- ), "some %% other value"
+ t.c.data == literal_column("'some %% other value'")
+ )
+ ),
+ "some %% other value",
)
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
index c0b6b18eb..6257451eb 100644
--- a/lib/sqlalchemy/testing/suite/test_insert.py
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -10,53 +10,48 @@ from ..schema import Table, Column
class LastrowidTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
__backend__ = True
- __requires__ = 'implements_get_lastrowid', 'autoincrement_insert'
+ __requires__ = "implements_get_lastrowid", "autoincrement_insert"
__engine_options__ = {"implicit_returning": False}
@classmethod
def define_tables(cls, metadata):
- Table('autoinc_pk', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50))
- )
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
- Table('manual_pk', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50))
- )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
- eq_(
- row,
- (config.db.dialect.default_sequence_base, "some data")
- )
+ eq_(row, (config.db.dialect.default_sequence_base, "some data"))
def test_autoincrement_on_insert(self):
- config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
- )
+ config.db.execute(self.tables.autoinc_pk.insert(), data="some data")
self._assert_round_trip(self.tables.autoinc_pk, config.db)
def test_last_inserted_id(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
- eq_(
- r.inserted_primary_key,
- [pk]
- )
+ eq_(r.inserted_primary_key, [pk])
# failed on pypy1.9 but seems to be OK on pypy 2.1
# @exclusions.fails_if(lambda: util.pypy,
@@ -65,50 +60,57 @@ class LastrowidTest(fixtures.TablesTest):
@requirements.dbapi_lastrowid
def test_native_lastrowid_autoinc(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
lastrowid = r.lastrowid
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
- eq_(
- lastrowid, pk
- )
+ eq_(lastrowid, pk)
class InsertBehaviorTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('autoinc_pk', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50))
- )
- Table('manual_pk', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('data', String(50))
- )
- Table('includes_defaults', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50)),
- Column('x', Integer, default=5),
- Column('y', Integer,
- default=literal_column("2", type_=Integer) + literal(2)))
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+ Table(
+ "includes_defaults",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ Column("x", Integer, default=5),
+ Column(
+ "y",
+ Integer,
+ default=literal_column("2", type_=Integer) + literal(2),
+ ),
+ )
def test_autoclose_on_insert(self):
if requirements.returning.enabled:
engine = engines.testing_engine(
- options={'implicit_returning': False})
+ options={"implicit_returning": False}
+ )
else:
engine = config.db
- r = engine.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
- )
+ r = engine.execute(self.tables.autoinc_pk.insert(), data="some data")
assert r._soft_closed
assert not r.closed
assert r.is_insert
@@ -117,8 +119,7 @@ class InsertBehaviorTest(fixtures.TablesTest):
@requirements.returning
def test_autoclose_on_insert_implicit_returning(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
assert r._soft_closed
assert not r.closed
@@ -127,15 +128,14 @@ class InsertBehaviorTest(fixtures.TablesTest):
@requirements.empty_inserts
def test_empty_insert(self):
- r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- )
+ r = config.db.execute(self.tables.autoinc_pk.insert())
assert r._soft_closed
assert not r.closed
r = config.db.execute(
- self.tables.autoinc_pk.select().
- where(self.tables.autoinc_pk.c.id != None)
+ self.tables.autoinc_pk.select().where(
+ self.tables.autoinc_pk.c.id != None
+ )
)
assert len(r.fetchall())
@@ -150,15 +150,15 @@ class InsertBehaviorTest(fixtures.TablesTest):
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
- ]
+ ],
)
result = config.db.execute(
- dest_table.insert().
- from_select(
+ dest_table.insert().from_select(
("data",),
- select([src_table.c.data]).
- where(src_table.c.data.in_(["data2", "data3"]))
+ select([src_table.c.data]).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
)
)
@@ -167,7 +167,7 @@ class InsertBehaviorTest(fixtures.TablesTest):
result = config.db.execute(
select([dest_table.c.data]).order_by(dest_table.c.data)
)
- eq_(result.fetchall(), [("data2", ), ("data3", )])
+ eq_(result.fetchall(), [("data2",), ("data3",)])
@requirements.insert_from_select
def test_insert_from_select_autoinc_no_rows(self):
@@ -175,11 +175,11 @@ class InsertBehaviorTest(fixtures.TablesTest):
dest_table = self.tables.autoinc_pk
result = config.db.execute(
- dest_table.insert().
- from_select(
+ dest_table.insert().from_select(
("data",),
- select([src_table.c.data]).
- where(src_table.c.data.in_(["data2", "data3"]))
+ select([src_table.c.data]).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
)
)
eq_(result.inserted_primary_key, [None])
@@ -199,23 +199,23 @@ class InsertBehaviorTest(fixtures.TablesTest):
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
- ]
+ ],
)
config.db.execute(
- table.insert(inline=True).
- from_select(("id", "data",),
- select([table.c.id + 5, table.c.data]).
- where(table.c.data.in_(["data2", "data3"]))
- ),
+ table.insert(inline=True).from_select(
+ ("id", "data"),
+ select([table.c.id + 5, table.c.data]).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
)
eq_(
config.db.execute(
select([table.c.data]).order_by(table.c.data)
).fetchall(),
- [("data1", ), ("data2", ), ("data2", ),
- ("data3", ), ("data3", )]
+ [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
)
@requirements.insert_from_select
@@ -227,56 +227,60 @@ class InsertBehaviorTest(fixtures.TablesTest):
dict(id=1, data="data1"),
dict(id=2, data="data2"),
dict(id=3, data="data3"),
- ]
+ ],
)
config.db.execute(
- table.insert(inline=True).
- from_select(("id", "data",),
- select([table.c.id + 5, table.c.data]).
- where(table.c.data.in_(["data2", "data3"]))
- ),
+ table.insert(inline=True).from_select(
+ ("id", "data"),
+ select([table.c.id + 5, table.c.data]).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
)
eq_(
config.db.execute(
select([table]).order_by(table.c.data, table.c.id)
).fetchall(),
- [(1, 'data1', 5, 4), (2, 'data2', 5, 4),
- (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)]
+ [
+ (1, "data1", 5, 4),
+ (2, "data2", 5, 4),
+ (7, "data2", 5, 4),
+ (3, "data3", 5, 4),
+ (8, "data3", 5, 4),
+ ],
)
class ReturningTest(fixtures.TablesTest):
- run_create_tables = 'each'
- __requires__ = 'returning', 'autoincrement_insert'
+ run_create_tables = "each"
+ __requires__ = "returning", "autoincrement_insert"
__backend__ = True
__engine_options__ = {"implicit_returning": True}
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
- eq_(
- row,
- (config.db.dialect.default_sequence_base, "some data")
- )
+ eq_(row, (config.db.dialect.default_sequence_base, "some data"))
@classmethod
def define_tables(cls, metadata):
- Table('autoinc_pk', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('data', String(50))
- )
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
@requirements.fetch_rows_post_commit
def test_explicit_returning_pk_autocommit(self):
engine = config.db
table = self.tables.autoinc_pk
r = engine.execute(
- table.insert().returning(
- table.c.id),
- data="some data"
+ table.insert().returning(table.c.id), data="some data"
)
pk = r.first()[0]
fetched_pk = config.db.scalar(select([table.c.id]))
@@ -287,9 +291,7 @@ class ReturningTest(fixtures.TablesTest):
table = self.tables.autoinc_pk
with engine.begin() as conn:
r = conn.execute(
- table.insert().returning(
- table.c.id),
- data="some data"
+ table.insert().returning(table.c.id), data="some data"
)
pk = r.first()[0]
fetched_pk = config.db.scalar(select([table.c.id]))
@@ -297,23 +299,16 @@ class ReturningTest(fixtures.TablesTest):
def test_autoincrement_on_insert_implcit_returning(self):
- config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
- )
+ config.db.execute(self.tables.autoinc_pk.insert(), data="some data")
self._assert_round_trip(self.tables.autoinc_pk, config.db)
def test_last_inserted_id_implicit_returning(self):
r = config.db.execute(
- self.tables.autoinc_pk.insert(),
- data="some data"
+ self.tables.autoinc_pk.insert(), data="some data"
)
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
- eq_(
- r.inserted_primary_key,
- [pk]
- )
+ eq_(r.inserted_primary_key, [pk])
-__all__ = ('LastrowidTest', 'InsertBehaviorTest', 'ReturningTest')
+__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 00a5aac01..bfed5f1ab 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -1,5 +1,3 @@
-
-
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
from sqlalchemy import types as sql_types
@@ -26,10 +24,12 @@ class HasTableTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table('test_table', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
def test_has_table(self):
with config.db.begin() as conn:
@@ -46,8 +46,10 @@ class ComponentReflectionTest(fixtures.TablesTest):
def setup_bind(cls):
if config.requirements.independent_connections.enabled:
from sqlalchemy import pool
+
return engines.testing_engine(
- options=dict(poolclass=pool.StaticPool))
+ options=dict(poolclass=pool.StaticPool)
+ )
else:
return config.db
@@ -65,86 +67,109 @@ class ComponentReflectionTest(fixtures.TablesTest):
schema_prefix = ""
if testing.requires.self_referential_foreign_keys.enabled:
- users = Table('users', metadata,
- Column('user_id', sa.INT, primary_key=True),
- Column('test1', sa.CHAR(5), nullable=False),
- Column('test2', sa.Float(5), nullable=False),
- Column('parent_user_id', sa.Integer,
- sa.ForeignKey('%susers.user_id' %
- schema_prefix,
- name='user_id_fk')),
- schema=schema,
- test_needs_fk=True,
- )
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ Column(
+ "parent_user_id",
+ sa.Integer,
+ sa.ForeignKey(
+ "%susers.user_id" % schema_prefix, name="user_id_fk"
+ ),
+ ),
+ schema=schema,
+ test_needs_fk=True,
+ )
else:
- users = Table('users', metadata,
- Column('user_id', sa.INT, primary_key=True),
- Column('test1', sa.CHAR(5), nullable=False),
- Column('test2', sa.Float(5), nullable=False),
- schema=schema,
- test_needs_fk=True,
- )
-
- Table("dingalings", metadata,
- Column('dingaling_id', sa.Integer, primary_key=True),
- Column('address_id', sa.Integer,
- sa.ForeignKey('%semail_addresses.address_id' %
- schema_prefix)),
- Column('data', sa.String(30)),
- schema=schema,
- test_needs_fk=True,
- )
- Table('email_addresses', metadata,
- Column('address_id', sa.Integer),
- Column('remote_user_id', sa.Integer,
- sa.ForeignKey(users.c.user_id)),
- Column('email_address', sa.String(20)),
- sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'),
- schema=schema,
- test_needs_fk=True,
- )
- Table('comment_test', metadata,
- Column('id', sa.Integer, primary_key=True, comment='id comment'),
- Column('data', sa.String(20), comment='data % comment'),
- Column(
- 'd2', sa.String(20),
- comment=r"""Comment types type speedily ' " \ '' Fun!"""),
- schema=schema,
- comment=r"""the test % ' " \ table comment""")
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ schema=schema,
+ test_needs_fk=True,
+ )
+
+ Table(
+ "dingalings",
+ metadata,
+ Column("dingaling_id", sa.Integer, primary_key=True),
+ Column(
+ "address_id",
+ sa.Integer,
+ sa.ForeignKey("%semail_addresses.address_id" % schema_prefix),
+ ),
+ Column("data", sa.String(30)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "email_addresses",
+ metadata,
+ Column("address_id", sa.Integer),
+ Column(
+ "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id)
+ ),
+ Column("email_address", sa.String(20)),
+ sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "comment_test",
+ metadata,
+ Column("id", sa.Integer, primary_key=True, comment="id comment"),
+ Column("data", sa.String(20), comment="data % comment"),
+ Column(
+ "d2",
+ sa.String(20),
+ comment=r"""Comment types type speedily ' " \ '' Fun!""",
+ ),
+ schema=schema,
+ comment=r"""the test % ' " \ table comment""",
+ )
if testing.requires.cross_schema_fk_reflection.enabled:
if schema is None:
Table(
- 'local_table', metadata,
- Column('id', sa.Integer, primary_key=True),
- Column('data', sa.String(20)),
+ "local_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
Column(
- 'remote_id',
+ "remote_id",
ForeignKey(
- '%s.remote_table_2.id' %
- testing.config.test_schema)
+ "%s.remote_table_2.id" % testing.config.test_schema
+ ),
),
test_needs_fk=True,
- schema=config.db.dialect.default_schema_name
+ schema=config.db.dialect.default_schema_name,
)
else:
Table(
- 'remote_table', metadata,
- Column('id', sa.Integer, primary_key=True),
+ "remote_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
Column(
- 'local_id',
+ "local_id",
ForeignKey(
- '%s.local_table.id' %
- config.db.dialect.default_schema_name)
+ "%s.local_table.id"
+ % config.db.dialect.default_schema_name
+ ),
),
- Column('data', sa.String(20)),
+ Column("data", sa.String(20)),
schema=schema,
test_needs_fk=True,
)
Table(
- 'remote_table_2', metadata,
- Column('id', sa.Integer, primary_key=True),
- Column('data', sa.String(20)),
+ "remote_table_2",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
schema=schema,
test_needs_fk=True,
)
@@ -155,19 +180,21 @@ class ComponentReflectionTest(fixtures.TablesTest):
if not schema:
# test_needs_fk is at the moment to force MySQL InnoDB
noncol_idx_test_nopk = Table(
- 'noncol_idx_test_nopk', metadata,
- Column('q', sa.String(5)),
+ "noncol_idx_test_nopk",
+ metadata,
+ Column("q", sa.String(5)),
test_needs_fk=True,
)
noncol_idx_test_pk = Table(
- 'noncol_idx_test_pk', metadata,
- Column('id', sa.Integer, primary_key=True),
- Column('q', sa.String(5)),
+ "noncol_idx_test_pk",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("q", sa.String(5)),
test_needs_fk=True,
)
- Index('noncol_idx_nopk', noncol_idx_test_nopk.c.q.desc())
- Index('noncol_idx_pk', noncol_idx_test_pk.c.q.desc())
+ Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc())
+ Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc())
if testing.requires.view_column_reflection.enabled:
cls.define_views(metadata, schema)
@@ -180,34 +207,35 @@ class ComponentReflectionTest(fixtures.TablesTest):
# temp table fixture
if testing.against("oracle"):
kw = {
- 'prefixes': ["GLOBAL TEMPORARY"],
- 'oracle_on_commit': 'PRESERVE ROWS'
+ "prefixes": ["GLOBAL TEMPORARY"],
+ "oracle_on_commit": "PRESERVE ROWS",
}
else:
- kw = {
- 'prefixes': ["TEMPORARY"],
- }
+ kw = {"prefixes": ["TEMPORARY"]}
user_tmp = Table(
- "user_tmp", metadata,
+ "user_tmp",
+ metadata,
Column("id", sa.INT, primary_key=True),
- Column('name', sa.VARCHAR(50)),
- Column('foo', sa.INT),
- sa.UniqueConstraint('name', name='user_tmp_uq'),
+ Column("name", sa.VARCHAR(50)),
+ Column("foo", sa.INT),
+ sa.UniqueConstraint("name", name="user_tmp_uq"),
sa.Index("user_tmp_ix", "foo"),
**kw
)
- if testing.requires.view_reflection.enabled and \
- testing.requires.temporary_views.enabled:
- event.listen(
- user_tmp, "after_create",
- DDL("create temporary view user_tmp_v as "
- "select * from user_tmp")
- )
+ if (
+ testing.requires.view_reflection.enabled
+ and testing.requires.temporary_views.enabled
+ ):
event.listen(
- user_tmp, "before_drop",
- DDL("drop view user_tmp_v")
+ user_tmp,
+ "after_create",
+ DDL(
+ "create temporary view user_tmp_v as "
+ "select * from user_tmp"
+ ),
)
+ event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v"))
@classmethod
def define_index(cls, metadata, users):
@@ -216,23 +244,19 @@ class ComponentReflectionTest(fixtures.TablesTest):
@classmethod
def define_views(cls, metadata, schema):
- for table_name in ('users', 'email_addresses'):
+ for table_name in ("users", "email_addresses"):
fullname = table_name
if schema:
fullname = "%s.%s" % (schema, table_name)
- view_name = fullname + '_v'
+ view_name = fullname + "_v"
query = "CREATE VIEW %s AS SELECT * FROM %s" % (
- view_name, fullname)
-
- event.listen(
- metadata,
- "after_create",
- DDL(query)
+ view_name,
+ fullname,
)
+
+ event.listen(metadata, "after_create", DDL(query))
event.listen(
- metadata,
- "before_drop",
- DDL("DROP VIEW %s" % view_name)
+ metadata, "before_drop", DDL("DROP VIEW %s" % view_name)
)
@testing.requires.schema_reflection
@@ -244,9 +268,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.schema_reflection
def test_dialect_initialize(self):
engine = engines.testing_engine()
- assert not hasattr(engine.dialect, 'default_schema_name')
+ assert not hasattr(engine.dialect, "default_schema_name")
inspect(engine)
- assert hasattr(engine.dialect, 'default_schema_name')
+ assert hasattr(engine.dialect, "default_schema_name")
@testing.requires.schema_reflection
def test_get_default_schema_name(self):
@@ -254,40 +278,49 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(insp.default_schema_name, testing.db.dialect.default_schema_name)
@testing.provide_metadata
- def _test_get_table_names(self, schema=None, table_type='table',
- order_by=None):
+ def _test_get_table_names(
+ self, schema=None, table_type="table", order_by=None
+ ):
_ignore_tables = [
- 'comment_test', 'noncol_idx_test_pk', 'noncol_idx_test_nopk',
- 'local_table', 'remote_table', 'remote_table_2'
+ "comment_test",
+ "noncol_idx_test_pk",
+ "noncol_idx_test_nopk",
+ "local_table",
+ "remote_table",
+ "remote_table_2",
]
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
insp = inspect(meta.bind)
- if table_type == 'view':
+ if table_type == "view":
table_names = insp.get_view_names(schema)
table_names.sort()
- answer = ['email_addresses_v', 'users_v']
+ answer = ["email_addresses_v", "users_v"]
eq_(sorted(table_names), answer)
else:
table_names = [
- t for t in insp.get_table_names(
- schema,
- order_by=order_by) if t not in _ignore_tables]
+ t
+ for t in insp.get_table_names(schema, order_by=order_by)
+ if t not in _ignore_tables
+ ]
- if order_by == 'foreign_key':
- answer = ['users', 'email_addresses', 'dingalings']
+ if order_by == "foreign_key":
+ answer = ["users", "email_addresses", "dingalings"]
eq_(table_names, answer)
else:
- answer = ['dingalings', 'email_addresses', 'users']
+ answer = ["dingalings", "email_addresses", "users"]
eq_(sorted(table_names), answer)
@testing.requires.temp_table_names
def test_get_temp_table_names(self):
insp = inspect(self.bind)
temp_table_names = insp.get_temp_table_names()
- eq_(sorted(temp_table_names), ['user_tmp'])
+ eq_(sorted(temp_table_names), ["user_tmp"])
@testing.requires.view_reflection
@testing.requires.temp_table_names
@@ -295,7 +328,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
def test_get_temp_view_names(self):
insp = inspect(self.bind)
temp_table_names = insp.get_temp_view_names()
- eq_(sorted(temp_table_names), ['user_tmp_v'])
+ eq_(sorted(temp_table_names), ["user_tmp_v"])
@testing.requires.table_reflection
def test_get_table_names(self):
@@ -304,7 +337,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.table_reflection
@testing.requires.foreign_key_constraint_reflection
def test_get_table_names_fks(self):
- self._test_get_table_names(order_by='foreign_key')
+ self._test_get_table_names(order_by="foreign_key")
@testing.requires.comment_reflection
def test_get_comments(self):
@@ -320,26 +353,24 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(
insp.get_table_comment("comment_test", schema=schema),
- {"text": r"""the test % ' " \ table comment"""}
+ {"text": r"""the test % ' " \ table comment"""},
)
- eq_(
- insp.get_table_comment("users", schema=schema),
- {"text": None}
- )
+ eq_(insp.get_table_comment("users", schema=schema), {"text": None})
eq_(
[
- {"name": rec['name'], "comment": rec['comment']}
- for rec in
- insp.get_columns("comment_test", schema=schema)
+ {"name": rec["name"], "comment": rec["comment"]}
+ for rec in insp.get_columns("comment_test", schema=schema)
],
[
- {'comment': 'id comment', 'name': 'id'},
- {'comment': 'data % comment', 'name': 'data'},
- {'comment': r"""Comment types type speedily ' " \ '' Fun!""",
- 'name': 'd2'}
- ]
+ {"comment": "id comment", "name": "id"},
+ {"comment": "data % comment", "name": "data"},
+ {
+ "comment": r"""Comment types type speedily ' " \ '' Fun!""",
+ "name": "d2",
+ },
+ ],
)
@testing.requires.table_reflection
@@ -349,30 +380,33 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.view_column_reflection
def test_get_view_names(self):
- self._test_get_table_names(table_type='view')
+ self._test_get_table_names(table_type="view")
@testing.requires.view_column_reflection
@testing.requires.schemas
def test_get_view_names_with_schema(self):
self._test_get_table_names(
- testing.config.test_schema, table_type='view')
+ testing.config.test_schema, table_type="view"
+ )
@testing.requires.table_reflection
@testing.requires.view_column_reflection
def test_get_tables_and_views(self):
self._test_get_table_names()
- self._test_get_table_names(table_type='view')
+ self._test_get_table_names(table_type="view")
- def _test_get_columns(self, schema=None, table_type='table'):
+ def _test_get_columns(self, schema=None, table_type="table"):
meta = MetaData(testing.db)
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
- table_names = ['users', 'email_addresses']
- if table_type == 'view':
- table_names = ['users_v', 'email_addresses_v']
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
+ table_names = ["users", "email_addresses"]
+ if table_type == "view":
+ table_names = ["users_v", "email_addresses_v"]
insp = inspect(meta.bind)
- for table_name, table in zip(table_names, (users,
- addresses)):
+ for table_name, table in zip(table_names, (users, addresses)):
schema_name = schema
cols = insp.get_columns(table_name, schema=schema_name)
self.assert_(len(cols) > 0, len(cols))
@@ -380,36 +414,46 @@ class ComponentReflectionTest(fixtures.TablesTest):
# should be in order
for i, col in enumerate(table.columns):
- eq_(col.name, cols[i]['name'])
- ctype = cols[i]['type'].__class__
+ eq_(col.name, cols[i]["name"])
+ ctype = cols[i]["type"].__class__
ctype_def = col.type
if isinstance(ctype_def, sa.types.TypeEngine):
ctype_def = ctype_def.__class__
# Oracle returns Date for DateTime.
- if testing.against('oracle') and ctype_def \
- in (sql_types.Date, sql_types.DateTime):
+ if testing.against("oracle") and ctype_def in (
+ sql_types.Date,
+ sql_types.DateTime,
+ ):
ctype_def = sql_types.Date
# assert that the desired type and return type share
# a base within one of the generic types.
- self.assert_(len(set(ctype.__mro__).
- intersection(ctype_def.__mro__).
- intersection([
- sql_types.Integer,
- sql_types.Numeric,
- sql_types.DateTime,
- sql_types.Date,
- sql_types.Time,
- sql_types.String,
- sql_types._Binary,
- ])) > 0, '%s(%s), %s(%s)' %
- (col.name, col.type, cols[i]['name'], ctype))
+ self.assert_(
+ len(
+ set(ctype.__mro__)
+ .intersection(ctype_def.__mro__)
+ .intersection(
+ [
+ sql_types.Integer,
+ sql_types.Numeric,
+ sql_types.DateTime,
+ sql_types.Date,
+ sql_types.Time,
+ sql_types.String,
+ sql_types._Binary,
+ ]
+ )
+ )
+ > 0,
+ "%s(%s), %s(%s)"
+ % (col.name, col.type, cols[i]["name"], ctype),
+ )
if not col.primary_key:
- assert cols[i]['default'] is None
+ assert cols[i]["default"] is None
@testing.requires.table_reflection
def test_get_columns(self):
@@ -417,24 +461,20 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.provide_metadata
def _type_round_trip(self, *types):
- t = Table('t', self.metadata,
- *[
- Column('t%d' % i, type_)
- for i, type_ in enumerate(types)
- ]
- )
+ t = Table(
+ "t",
+ self.metadata,
+ *[Column("t%d" % i, type_) for i, type_ in enumerate(types)]
+ )
t.create()
return [
- c['type'] for c in
- inspect(self.metadata.bind).get_columns('t')
+ c["type"] for c in inspect(self.metadata.bind).get_columns("t")
]
@testing.requires.table_reflection
def test_numeric_reflection(self):
- for typ in self._type_round_trip(
- sql_types.Numeric(18, 5),
- ):
+ for typ in self._type_round_trip(sql_types.Numeric(18, 5)):
assert isinstance(typ, sql_types.Numeric)
eq_(typ.precision, 18)
eq_(typ.scale, 5)
@@ -448,16 +488,19 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.table_reflection
@testing.provide_metadata
def test_nullable_reflection(self):
- t = Table('t', self.metadata,
- Column('a', Integer, nullable=True),
- Column('b', Integer, nullable=False))
+ t = Table(
+ "t",
+ self.metadata,
+ Column("a", Integer, nullable=True),
+ Column("b", Integer, nullable=False),
+ )
t.create()
eq_(
dict(
- (col['name'], col['nullable'])
- for col in inspect(self.metadata.bind).get_columns('t')
+ (col["name"], col["nullable"])
+ for col in inspect(self.metadata.bind).get_columns("t")
),
- {"a": True, "b": False}
+ {"a": True, "b": False},
)
@testing.requires.table_reflection
@@ -470,32 +513,30 @@ class ComponentReflectionTest(fixtures.TablesTest):
meta = MetaData(self.bind)
user_tmp = self.tables.user_tmp
insp = inspect(meta.bind)
- cols = insp.get_columns('user_tmp')
+ cols = insp.get_columns("user_tmp")
self.assert_(len(cols) > 0, len(cols))
for i, col in enumerate(user_tmp.columns):
- eq_(col.name, cols[i]['name'])
+ eq_(col.name, cols[i]["name"])
@testing.requires.temp_table_reflection
@testing.requires.view_column_reflection
@testing.requires.temporary_views
def test_get_temp_view_columns(self):
insp = inspect(self.bind)
- cols = insp.get_columns('user_tmp_v')
- eq_(
- [col['name'] for col in cols],
- ['id', 'name', 'foo']
- )
+ cols = insp.get_columns("user_tmp_v")
+ eq_([col["name"] for col in cols], ["id", "name", "foo"])
@testing.requires.view_column_reflection
def test_get_view_columns(self):
- self._test_get_columns(table_type='view')
+ self._test_get_columns(table_type="view")
@testing.requires.view_column_reflection
@testing.requires.schemas
def test_get_view_columns_with_schema(self):
self._test_get_columns(
- schema=testing.config.test_schema, table_type='view')
+ schema=testing.config.test_schema, table_type="view"
+ )
@testing.provide_metadata
def _test_get_pk_constraint(self, schema=None):
@@ -504,15 +545,15 @@ class ComponentReflectionTest(fixtures.TablesTest):
insp = inspect(meta.bind)
users_cons = insp.get_pk_constraint(users.name, schema=schema)
- users_pkeys = users_cons['constrained_columns']
- eq_(users_pkeys, ['user_id'])
+ users_pkeys = users_cons["constrained_columns"]
+ eq_(users_pkeys, ["user_id"])
addr_cons = insp.get_pk_constraint(addresses.name, schema=schema)
- addr_pkeys = addr_cons['constrained_columns']
- eq_(addr_pkeys, ['address_id'])
+ addr_pkeys = addr_cons["constrained_columns"]
+ eq_(addr_pkeys, ["address_id"])
with testing.requires.reflects_pk_names.fail_if():
- eq_(addr_cons['name'], 'email_ad_pk')
+ eq_(addr_cons["name"], "email_ad_pk")
@testing.requires.primary_key_constraint_reflection
def test_get_pk_constraint(self):
@@ -534,44 +575,46 @@ class ComponentReflectionTest(fixtures.TablesTest):
sa_exc.SADeprecationWarning,
"Call to deprecated method get_primary_keys."
" Use get_pk_constraint instead.",
- insp.get_primary_keys, users.name
+ insp.get_primary_keys,
+ users.name,
)
@testing.provide_metadata
def _test_get_foreign_keys(self, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
insp = inspect(meta.bind)
expected_schema = schema
# users
if testing.requires.self_referential_foreign_keys.enabled:
- users_fkeys = insp.get_foreign_keys(users.name,
- schema=schema)
+ users_fkeys = insp.get_foreign_keys(users.name, schema=schema)
fkey1 = users_fkeys[0]
with testing.requires.named_constraints.fail_if():
- eq_(fkey1['name'], "user_id_fk")
+ eq_(fkey1["name"], "user_id_fk")
- eq_(fkey1['referred_schema'], expected_schema)
- eq_(fkey1['referred_table'], users.name)
- eq_(fkey1['referred_columns'], ['user_id', ])
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
if testing.requires.self_referential_foreign_keys.enabled:
- eq_(fkey1['constrained_columns'], ['parent_user_id'])
+ eq_(fkey1["constrained_columns"], ["parent_user_id"])
# addresses
- addr_fkeys = insp.get_foreign_keys(addresses.name,
- schema=schema)
+ addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema)
fkey1 = addr_fkeys[0]
with testing.requires.implicitly_named_constraints.fail_if():
- self.assert_(fkey1['name'] is not None)
+ self.assert_(fkey1["name"] is not None)
- eq_(fkey1['referred_schema'], expected_schema)
- eq_(fkey1['referred_table'], users.name)
- eq_(fkey1['referred_columns'], ['user_id', ])
- eq_(fkey1['constrained_columns'], ['remote_user_id'])
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
+ eq_(fkey1["constrained_columns"], ["remote_user_id"])
@testing.requires.foreign_key_constraint_reflection
def test_get_foreign_keys(self):
@@ -586,9 +629,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.schemas
def test_get_inter_schema_foreign_keys(self):
local_table, remote_table, remote_table_2 = self.tables(
- '%s.local_table' % testing.db.dialect.default_schema_name,
- '%s.remote_table' % testing.config.test_schema,
- '%s.remote_table_2' % testing.config.test_schema
+ "%s.local_table" % testing.db.dialect.default_schema_name,
+ "%s.remote_table" % testing.config.test_schema,
+ "%s.remote_table_2" % testing.config.test_schema,
)
insp = inspect(config.db)
@@ -597,25 +640,25 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(len(local_fkeys), 1)
fkey1 = local_fkeys[0]
- eq_(fkey1['referred_schema'], testing.config.test_schema)
- eq_(fkey1['referred_table'], remote_table_2.name)
- eq_(fkey1['referred_columns'], ['id', ])
- eq_(fkey1['constrained_columns'], ['remote_id'])
+ eq_(fkey1["referred_schema"], testing.config.test_schema)
+ eq_(fkey1["referred_table"], remote_table_2.name)
+ eq_(fkey1["referred_columns"], ["id"])
+ eq_(fkey1["constrained_columns"], ["remote_id"])
remote_fkeys = insp.get_foreign_keys(
- remote_table.name, schema=testing.config.test_schema)
+ remote_table.name, schema=testing.config.test_schema
+ )
eq_(len(remote_fkeys), 1)
fkey2 = remote_fkeys[0]
- assert fkey2['referred_schema'] in (
+ assert fkey2["referred_schema"] in (
None,
- testing.db.dialect.default_schema_name
+ testing.db.dialect.default_schema_name,
)
- eq_(fkey2['referred_table'], local_table.name)
- eq_(fkey2['referred_columns'], ['id', ])
- eq_(fkey2['constrained_columns'], ['local_id'])
-
+ eq_(fkey2["referred_table"], local_table.name)
+ eq_(fkey2["referred_columns"], ["id"])
+ eq_(fkey2["constrained_columns"], ["local_id"])
@testing.requires.foreign_key_constraint_option_reflection_ondelete
def test_get_foreign_key_options_ondelete(self):
@@ -630,26 +673,32 @@ class ComponentReflectionTest(fixtures.TablesTest):
meta = self.metadata
Table(
- 'x', meta,
- Column('id', Integer, primary_key=True),
- test_needs_fk=True
- )
-
- Table('table', meta,
- Column('id', Integer, primary_key=True),
- Column('x_id', Integer, sa.ForeignKey('x.id', name='xid')),
- Column('test', String(10)),
- test_needs_fk=True)
-
- Table('user', meta,
- Column('id', Integer, primary_key=True),
- Column('name', String(50), nullable=False),
- Column('tid', Integer),
- sa.ForeignKeyConstraint(
- ['tid'], ['table.id'],
- name='myfk',
- **options),
- test_needs_fk=True)
+ "x",
+ meta,
+ Column("id", Integer, primary_key=True),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "table",
+ meta,
+ Column("id", Integer, primary_key=True),
+ Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")),
+ Column("test", String(10)),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "user",
+ meta,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("tid", Integer),
+ sa.ForeignKeyConstraint(
+ ["tid"], ["table.id"], name="myfk", **options
+ ),
+ test_needs_fk=True,
+ )
meta.create_all()
@@ -657,49 +706,44 @@ class ComponentReflectionTest(fixtures.TablesTest):
# test 'options' is always present for a backend
# that can reflect these, since alembic looks for this
- opts = insp.get_foreign_keys('table')[0]['options']
+ opts = insp.get_foreign_keys("table")[0]["options"]
- eq_(
- dict(
- (k, opts[k])
- for k in opts if opts[k]
- ),
- {}
- )
+ eq_(dict((k, opts[k]) for k in opts if opts[k]), {})
- opts = insp.get_foreign_keys('user')[0]['options']
- eq_(
- dict(
- (k, opts[k])
- for k in opts if opts[k]
- ),
- options
- )
+ opts = insp.get_foreign_keys("user")[0]["options"]
+ eq_(dict((k, opts[k]) for k in opts if opts[k]), options)
def _assert_insp_indexes(self, indexes, expected_indexes):
- index_names = [d['name'] for d in indexes]
+ index_names = [d["name"] for d in indexes]
for e_index in expected_indexes:
- assert e_index['name'] in index_names
- index = indexes[index_names.index(e_index['name'])]
+ assert e_index["name"] in index_names
+ index = indexes[index_names.index(e_index["name"])]
for key in e_index:
eq_(e_index[key], index[key])
@testing.provide_metadata
def _test_get_indexes(self, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
# The database may decide to create indexes for foreign keys, etc.
# so there may be more indexes than expected.
insp = inspect(meta.bind)
- indexes = insp.get_indexes('users', schema=schema)
+ indexes = insp.get_indexes("users", schema=schema)
expected_indexes = [
- {'unique': False,
- 'column_names': ['test1', 'test2'],
- 'name': 'users_t_idx'},
- {'unique': False,
- 'column_names': ['user_id', 'test2', 'test1'],
- 'name': 'users_all_idx'}
+ {
+ "unique": False,
+ "column_names": ["test1", "test2"],
+ "name": "users_t_idx",
+ },
+ {
+ "unique": False,
+ "column_names": ["user_id", "test2", "test1"],
+ "name": "users_all_idx",
+ },
]
self._assert_insp_indexes(indexes, expected_indexes)
@@ -721,10 +765,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
# reflecting an index that has "x DESC" in it as the column.
# the DB may or may not give us "x", but make sure we get the index
# back, it has a name, it's connected to the table.
- expected_indexes = [
- {'unique': False,
- 'name': ixname}
- ]
+ expected_indexes = [{"unique": False, "name": ixname}]
self._assert_insp_indexes(indexes, expected_indexes)
t = Table(tname, meta, autoload_with=meta.bind)
@@ -748,24 +789,30 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.unique_constraint_reflection
def test_get_temp_table_unique_constraints(self):
insp = inspect(self.bind)
- reflected = insp.get_unique_constraints('user_tmp')
+ reflected = insp.get_unique_constraints("user_tmp")
for refl in reflected:
# Different dialects handle duplicate index and constraints
# differently, so ignore this flag
- refl.pop('duplicates_index', None)
- eq_(reflected, [{'column_names': ['name'], 'name': 'user_tmp_uq'}])
+ refl.pop("duplicates_index", None)
+ eq_(reflected, [{"column_names": ["name"], "name": "user_tmp_uq"}])
@testing.requires.temp_table_reflection
def test_get_temp_table_indexes(self):
insp = inspect(self.bind)
- indexes = insp.get_indexes('user_tmp')
+ indexes = insp.get_indexes("user_tmp")
for ind in indexes:
- ind.pop('dialect_options', None)
+ ind.pop("dialect_options", None)
eq_(
# TODO: we need to add better filtering for indexes/uq constraints
# that are doubled up
- [idx for idx in indexes if idx['name'] == 'user_tmp_ix'],
- [{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}]
+ [idx for idx in indexes if idx["name"] == "user_tmp_ix"],
+ [
+ {
+ "unique": False,
+ "column_names": ["foo"],
+ "name": "user_tmp_ix",
+ }
+ ],
)
@testing.requires.unique_constraint_reflection
@@ -783,36 +830,37 @@ class ComponentReflectionTest(fixtures.TablesTest):
# CREATE TABLE?
uniques = sorted(
[
- {'name': 'unique_a', 'column_names': ['a']},
- {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']},
- {'name': 'unique_c_a_b', 'column_names': ['c', 'a', 'b']},
- {'name': 'unique_asc_key', 'column_names': ['asc', 'key']},
- {'name': 'i.have.dots', 'column_names': ['b']},
- {'name': 'i have spaces', 'column_names': ['c']},
+ {"name": "unique_a", "column_names": ["a"]},
+ {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]},
+ {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]},
+ {"name": "unique_asc_key", "column_names": ["asc", "key"]},
+ {"name": "i.have.dots", "column_names": ["b"]},
+ {"name": "i have spaces", "column_names": ["c"]},
],
- key=operator.itemgetter('name')
+ key=operator.itemgetter("name"),
)
orig_meta = self.metadata
table = Table(
- 'testtbl', orig_meta,
- Column('a', sa.String(20)),
- Column('b', sa.String(30)),
- Column('c', sa.Integer),
+ "testtbl",
+ orig_meta,
+ Column("a", sa.String(20)),
+ Column("b", sa.String(30)),
+ Column("c", sa.Integer),
# reserved identifiers
- Column('asc', sa.String(30)),
- Column('key', sa.String(30)),
- schema=schema
+ Column("asc", sa.String(30)),
+ Column("key", sa.String(30)),
+ schema=schema,
)
for uc in uniques:
table.append_constraint(
- sa.UniqueConstraint(*uc['column_names'], name=uc['name'])
+ sa.UniqueConstraint(*uc["column_names"], name=uc["name"])
)
orig_meta.create_all()
inspector = inspect(orig_meta.bind)
reflected = sorted(
- inspector.get_unique_constraints('testtbl', schema=schema),
- key=operator.itemgetter('name')
+ inspector.get_unique_constraints("testtbl", schema=schema),
+ key=operator.itemgetter("name"),
)
names_that_duplicate_index = set()
@@ -820,25 +868,31 @@ class ComponentReflectionTest(fixtures.TablesTest):
for orig, refl in zip(uniques, reflected):
# Different dialects handle duplicate index and constraints
# differently, so ignore this flag
- dupe = refl.pop('duplicates_index', None)
+ dupe = refl.pop("duplicates_index", None)
if dupe:
names_that_duplicate_index.add(dupe)
eq_(orig, refl)
reflected_metadata = MetaData()
reflected = Table(
- 'testtbl', reflected_metadata, autoload_with=orig_meta.bind,
- schema=schema)
+ "testtbl",
+ reflected_metadata,
+ autoload_with=orig_meta.bind,
+ schema=schema,
+ )
# test "deduplicates for index" logic. MySQL and Oracle
# "unique constraints" are actually unique indexes (with possible
# exception of a unique that is a dupe of another one in the case
# of Oracle). make sure # they aren't duplicated.
idx_names = set([idx.name for idx in reflected.indexes])
- uq_names = set([
- uq.name for uq in reflected.constraints
- if isinstance(uq, sa.UniqueConstraint)]).difference(
- ['unique_c_a_b'])
+ uq_names = set(
+ [
+ uq.name
+ for uq in reflected.constraints
+ if isinstance(uq, sa.UniqueConstraint)
+ ]
+ ).difference(["unique_c_a_b"])
assert not idx_names.intersection(uq_names)
if names_that_duplicate_index:
@@ -858,47 +912,52 @@ class ComponentReflectionTest(fixtures.TablesTest):
def _test_get_check_constraints(self, schema=None):
orig_meta = self.metadata
Table(
- 'sa_cc', orig_meta,
- Column('a', Integer()),
- sa.CheckConstraint('a > 1 AND a < 5', name='cc1'),
- sa.CheckConstraint('a = 1 OR (a > 2 AND a < 5)', name='cc2'),
- schema=schema
+ "sa_cc",
+ orig_meta,
+ Column("a", Integer()),
+ sa.CheckConstraint("a > 1 AND a < 5", name="cc1"),
+ sa.CheckConstraint("a = 1 OR (a > 2 AND a < 5)", name="cc2"),
+ schema=schema,
)
orig_meta.create_all()
inspector = inspect(orig_meta.bind)
reflected = sorted(
- inspector.get_check_constraints('sa_cc', schema=schema),
- key=operator.itemgetter('name')
+ inspector.get_check_constraints("sa_cc", schema=schema),
+ key=operator.itemgetter("name"),
)
# trying to minimize effect of quoting, parenthesis, etc.
# may need to add more to this as new dialects get CHECK
# constraint reflection support
def normalize(sqltext):
- return " ".join(re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I))
+ return " ".join(
+ re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I)
+ )
reflected = [
- {"name": item["name"],
- "sqltext": normalize(item["sqltext"])}
+ {"name": item["name"], "sqltext": normalize(item["sqltext"])}
for item in reflected
]
eq_(
reflected,
[
- {'name': 'cc1', 'sqltext': 'a > 1 and a < 5'},
- {'name': 'cc2', 'sqltext': 'a = 1 or a > 2 and a < 5'}
- ]
+ {"name": "cc1", "sqltext": "a > 1 and a < 5"},
+ {"name": "cc2", "sqltext": "a = 1 or a > 2 and a < 5"},
+ ],
)
@testing.provide_metadata
def _test_get_view_definition(self, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
- view_name1 = 'users_v'
- view_name2 = 'email_addresses_v'
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
+ view_name1 = "users_v"
+ view_name2 = "email_addresses_v"
insp = inspect(meta.bind)
v1 = insp.get_view_definition(view_name1, schema=schema)
self.assert_(v1)
@@ -918,18 +977,21 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.provide_metadata
def _test_get_table_oid(self, table_name, schema=None):
meta = self.metadata
- users, addresses, dingalings = self.tables.users, \
- self.tables.email_addresses, self.tables.dingalings
+ users, addresses, dingalings = (
+ self.tables.users,
+ self.tables.email_addresses,
+ self.tables.dingalings,
+ )
insp = inspect(meta.bind)
oid = insp.get_table_oid(table_name, schema)
self.assert_(isinstance(oid, int))
def test_get_table_oid(self):
- self._test_get_table_oid('users')
+ self._test_get_table_oid("users")
@testing.requires.schemas
def test_get_table_oid_with_schema(self):
- self._test_get_table_oid('users', schema=testing.config.test_schema)
+ self._test_get_table_oid("users", schema=testing.config.test_schema)
@testing.requires.table_reflection
@testing.provide_metadata
@@ -950,49 +1012,53 @@ class ComponentReflectionTest(fixtures.TablesTest):
insp = inspect(meta.bind)
for tname, cname in [
- ('users', 'user_id'),
- ('email_addresses', 'address_id'),
- ('dingalings', 'dingaling_id'),
+ ("users", "user_id"),
+ ("email_addresses", "address_id"),
+ ("dingalings", "dingaling_id"),
]:
cols = insp.get_columns(tname)
- id_ = {c['name']: c for c in cols}[cname]
- assert id_.get('autoincrement', True)
+ id_ = {c["name"]: c for c in cols}[cname]
+ assert id_.get("autoincrement", True)
class NormalizedNameTest(fixtures.TablesTest):
- __requires__ = 'denormalized_names',
+ __requires__ = ("denormalized_names",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
Table(
- quoted_name('t1', quote=True), metadata,
- Column('id', Integer, primary_key=True),
+ quoted_name("t1", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
)
Table(
- quoted_name('t2', quote=True), metadata,
- Column('id', Integer, primary_key=True),
- Column('t1id', ForeignKey('t1.id'))
+ quoted_name("t2", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("t1id", ForeignKey("t1.id")),
)
def test_reflect_lowercase_forced_tables(self):
m2 = MetaData(testing.db)
- t2_ref = Table(quoted_name('t2', quote=True), m2, autoload=True)
- t1_ref = m2.tables['t1']
+ t2_ref = Table(quoted_name("t2", quote=True), m2, autoload=True)
+ t1_ref = m2.tables["t1"]
assert t2_ref.c.t1id.references(t1_ref.c.id)
m3 = MetaData(testing.db)
- m3.reflect(only=lambda name, m: name.lower() in ('t1', 't2'))
- assert m3.tables['t2'].c.t1id.references(m3.tables['t1'].c.id)
+ m3.reflect(only=lambda name, m: name.lower() in ("t1", "t2"))
+ assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id)
def test_get_table_names(self):
tablenames = [
- t for t in inspect(testing.db).get_table_names()
- if t.lower() in ("t1", "t2")]
+ t
+ for t in inspect(testing.db).get_table_names()
+ if t.lower() in ("t1", "t2")
+ ]
eq_(tablenames[0].upper(), tablenames[0].lower())
eq_(tablenames[1].upper(), tablenames[1].lower())
-__all__ = ('ComponentReflectionTest', 'HasTableTest', 'NormalizedNameTest')
+__all__ = ("ComponentReflectionTest", "HasTableTest", "NormalizedNameTest")
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
index f464d47eb..247f05cf5 100644
--- a/lib/sqlalchemy/testing/suite/test_results.py
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -15,14 +15,18 @@ class RowFetchTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table('plain_pk', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50))
- )
- Table('has_dates', metadata,
- Column('id', Integer, primary_key=True),
- Column('today', DateTime)
- )
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ Table(
+ "has_dates",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("today", DateTime),
+ )
@classmethod
def insert_data(cls):
@@ -32,65 +36,51 @@ class RowFetchTest(fixtures.TablesTest):
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
- ]
+ ],
)
config.db.execute(
cls.tables.has_dates.insert(),
- [
- {"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}
- ]
+ [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
)
def test_via_string(self):
row = config.db.execute(
- self.tables.plain_pk.select().
- order_by(self.tables.plain_pk.c.id)
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
- eq_(
- row['id'], 1
- )
- eq_(
- row['data'], "d1"
- )
+ eq_(row["id"], 1)
+ eq_(row["data"], "d1")
def test_via_int(self):
row = config.db.execute(
- self.tables.plain_pk.select().
- order_by(self.tables.plain_pk.c.id)
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
- eq_(
- row[0], 1
- )
- eq_(
- row[1], "d1"
- )
+ eq_(row[0], 1)
+ eq_(row[1], "d1")
def test_via_col_object(self):
row = config.db.execute(
- self.tables.plain_pk.select().
- order_by(self.tables.plain_pk.c.id)
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
).first()
- eq_(
- row[self.tables.plain_pk.c.id], 1
- )
- eq_(
- row[self.tables.plain_pk.c.data], "d1"
- )
+ eq_(row[self.tables.plain_pk.c.id], 1)
+ eq_(row[self.tables.plain_pk.c.data], "d1")
@requirements.duplicate_names_in_cursor_description
def test_row_with_dupe_names(self):
result = config.db.execute(
- select([self.tables.plain_pk.c.data,
- self.tables.plain_pk.c.data.label('data')]).
- order_by(self.tables.plain_pk.c.id)
+ select(
+ [
+ self.tables.plain_pk.c.data,
+ self.tables.plain_pk.c.data.label("data"),
+ ]
+ ).order_by(self.tables.plain_pk.c.id)
)
row = result.first()
- eq_(result.keys(), ['data', 'data'])
- eq_(row, ('d1', 'd1'))
+ eq_(result.keys(), ["data", "data"])
+ eq_(row, ("d1", "d1"))
def test_row_w_scalar_select(self):
"""test that a scalar select as a column is returned as such
@@ -101,11 +91,11 @@ class RowFetchTest(fixtures.TablesTest):
"""
datetable = self.tables.has_dates
- s = select([datetable.alias('x').c.today]).as_scalar()
- s2 = select([datetable.c.id, s.label('somelabel')])
+ s = select([datetable.alias("x").c.today]).as_scalar()
+ s2 = select([datetable.c.id, s.label("somelabel")])
row = config.db.execute(s2).first()
- eq_(row['somelabel'], datetime.datetime(2006, 5, 12, 12, 0, 0))
+ eq_(row["somelabel"], datetime.datetime(2006, 5, 12, 12, 0, 0))
class PercentSchemaNamesTest(fixtures.TablesTest):
@@ -117,29 +107,31 @@ class PercentSchemaNamesTest(fixtures.TablesTest):
"""
- __requires__ = ('percent_schema_names', )
+ __requires__ = ("percent_schema_names",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- cls.tables.percent_table = Table('percent%table', metadata,
- Column("percent%", Integer),
- Column(
- "spaces % more spaces", Integer),
- )
+ cls.tables.percent_table = Table(
+ "percent%table",
+ metadata,
+ Column("percent%", Integer),
+ Column("spaces % more spaces", Integer),
+ )
cls.tables.lightweight_percent_table = sql.table(
- 'percent%table', sql.column("percent%"),
- sql.column("spaces % more spaces")
+ "percent%table",
+ sql.column("percent%"),
+ sql.column("spaces % more spaces"),
)
def test_single_roundtrip(self):
percent_table = self.tables.percent_table
for params in [
- {'percent%': 5, 'spaces % more spaces': 12},
- {'percent%': 7, 'spaces % more spaces': 11},
- {'percent%': 9, 'spaces % more spaces': 10},
- {'percent%': 11, 'spaces % more spaces': 9}
+ {"percent%": 5, "spaces % more spaces": 12},
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
]:
config.db.execute(percent_table.insert(), params)
self._assert_table()
@@ -147,14 +139,15 @@ class PercentSchemaNamesTest(fixtures.TablesTest):
def test_executemany_roundtrip(self):
percent_table = self.tables.percent_table
config.db.execute(
- percent_table.insert(),
- {'percent%': 5, 'spaces % more spaces': 12}
+ percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
)
config.db.execute(
percent_table.insert(),
- [{'percent%': 7, 'spaces % more spaces': 11},
- {'percent%': 9, 'spaces % more spaces': 10},
- {'percent%': 11, 'spaces % more spaces': 9}]
+ [
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
+ ],
)
self._assert_table()
@@ -163,85 +156,81 @@ class PercentSchemaNamesTest(fixtures.TablesTest):
lightweight_percent_table = self.tables.lightweight_percent_table
for table in (
- percent_table,
- percent_table.alias(),
- lightweight_percent_table,
- lightweight_percent_table.alias()):
+ percent_table,
+ percent_table.alias(),
+ lightweight_percent_table,
+ lightweight_percent_table.alias(),
+ ):
eq_(
list(
config.db.execute(
- table.select().order_by(table.c['percent%'])
+ table.select().order_by(table.c["percent%"])
)
),
- [
- (5, 12),
- (7, 11),
- (9, 10),
- (11, 9)
- ]
+ [(5, 12), (7, 11), (9, 10), (11, 9)],
)
eq_(
list(
config.db.execute(
- table.select().
- where(table.c['spaces % more spaces'].in_([9, 10])).
- order_by(table.c['percent%']),
+ table.select()
+ .where(table.c["spaces % more spaces"].in_([9, 10]))
+ .order_by(table.c["percent%"])
)
),
- [
- (9, 10),
- (11, 9)
- ]
+ [(9, 10), (11, 9)],
)
- row = config.db.execute(table.select().
- order_by(table.c['percent%'])).first()
- eq_(row['percent%'], 5)
- eq_(row['spaces % more spaces'], 12)
+ row = config.db.execute(
+ table.select().order_by(table.c["percent%"])
+ ).first()
+ eq_(row["percent%"], 5)
+ eq_(row["spaces % more spaces"], 12)
- eq_(row[table.c['percent%']], 5)
- eq_(row[table.c['spaces % more spaces']], 12)
+ eq_(row[table.c["percent%"]], 5)
+ eq_(row[table.c["spaces % more spaces"]], 12)
config.db.execute(
percent_table.update().values(
- {percent_table.c['spaces % more spaces']: 15}
+ {percent_table.c["spaces % more spaces"]: 15}
)
)
eq_(
list(
config.db.execute(
- percent_table.
- select().
- order_by(percent_table.c['percent%'])
+ percent_table.select().order_by(
+ percent_table.c["percent%"]
+ )
)
),
- [(5, 15), (7, 15), (9, 15), (11, 15)]
+ [(5, 15), (7, 15), (9, 15), (11, 15)],
)
-class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
+class ServerSideCursorsTest(
+ fixtures.TestBase, testing.AssertsExecutionResults
+):
- __requires__ = ('server_side_cursors', )
+ __requires__ = ("server_side_cursors",)
__backend__ = True
def _is_server_side(self, cursor):
if self.engine.dialect.driver == "psycopg2":
return cursor.name
- elif self.engine.dialect.driver == 'pymysql':
- sscursor = __import__('pymysql.cursors').cursors.SSCursor
+ elif self.engine.dialect.driver == "pymysql":
+ sscursor = __import__("pymysql.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
elif self.engine.dialect.driver == "mysqldb":
- sscursor = __import__('MySQLdb.cursors').cursors.SSCursor
+ sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
else:
return False
def _fixture(self, server_side_cursors):
self.engine = engines.testing_engine(
- options={'server_side_cursors': server_side_cursors}
+ options={"server_side_cursors": server_side_cursors}
)
return self.engine
@@ -251,12 +240,12 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
def test_global_string(self):
engine = self._fixture(True)
- result = engine.execute('select 1')
+ result = engine.execute("select 1")
assert self._is_server_side(result.cursor)
def test_global_text(self):
engine = self._fixture(True)
- result = engine.execute(text('select 1'))
+ result = engine.execute(text("select 1"))
assert self._is_server_side(result.cursor)
def test_global_expr(self):
@@ -266,7 +255,7 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
def test_global_off_explicit(self):
engine = self._fixture(False)
- result = engine.execute(text('select 1'))
+ result = engine.execute(text("select 1"))
# It should be off globally ...
@@ -286,10 +275,11 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
engine = self._fixture(False)
# and this one
- result = \
- engine.connect().execution_options(stream_results=True).\
- execute('select 1'
- )
+ result = (
+ engine.connect()
+ .execution_options(stream_results=True)
+ .execute("select 1")
+ )
assert self._is_server_side(result.cursor)
def test_stmt_enabled_conn_option_disabled(self):
@@ -298,9 +288,9 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
s = select([1]).execution_options(stream_results=True)
# not this one
- result = \
- engine.connect().execution_options(stream_results=False).\
- execute(s)
+ result = (
+ engine.connect().execution_options(stream_results=False).execute(s)
+ )
assert not self._is_server_side(result.cursor)
def test_stmt_option_disabled(self):
@@ -329,18 +319,18 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
def test_for_update_string(self):
engine = self._fixture(True)
- result = engine.execute('SELECT 1 FOR UPDATE')
+ result = engine.execute("SELECT 1 FOR UPDATE")
assert self._is_server_side(result.cursor)
def test_text_no_ss(self):
engine = self._fixture(False)
- s = text('select 42')
+ s = text("select 42")
result = engine.execute(s)
assert not self._is_server_side(result.cursor)
def test_text_ss_option(self):
engine = self._fixture(False)
- s = text('select 42').execution_options(stream_results=True)
+ s = text("select 42").execution_options(stream_results=True)
result = engine.execute(s)
assert self._is_server_side(result.cursor)
@@ -349,19 +339,25 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
md = self.metadata
engine = self._fixture(True)
- test_table = Table('test_table', md,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)))
+ test_table = Table(
+ "test_table",
+ md,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
test_table.create(checkfirst=True)
- test_table.insert().execute(data='data1')
- test_table.insert().execute(data='data2')
- eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(),
- [(1, 'data1'), (2, 'data2')])
- test_table.update().where(
- test_table.c.id == 2).values(
- data=test_table.c.data +
- ' updated').execute()
- eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(),
- [(1, 'data1'), (2, 'data2 updated')])
+ test_table.insert().execute(data="data1")
+ test_table.insert().execute(data="data2")
+ eq_(
+ test_table.select().order_by(test_table.c.id).execute().fetchall(),
+ [(1, "data1"), (2, "data2")],
+ )
+ test_table.update().where(test_table.c.id == 2).values(
+ data=test_table.c.data + " updated"
+ ).execute()
+ eq_(
+ test_table.select().order_by(test_table.c.id).execute().fetchall(),
+ [(1, "data1"), (2, "data2 updated")],
+ )
test_table.delete().execute()
- eq_(select([func.count('*')]).select_from(test_table).scalar(), 0)
+ eq_(select([func.count("*")]).select_from(test_table).scalar(), 0)
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
index 73ce02492..032b68eb6 100644
--- a/lib/sqlalchemy/testing/suite/test_select.py
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -16,10 +16,12 @@ class CollateTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(100))
- )
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(100)),
+ )
@classmethod
def insert_data(cls):
@@ -28,26 +30,21 @@ class CollateTest(fixtures.TablesTest):
[
{"id": 1, "data": "collate data1"},
{"id": 2, "data": "collate data2"},
- ]
+ ],
)
def _assert_result(self, select, result):
- eq_(
- config.db.execute(select).fetchall(),
- result
- )
+ eq_(config.db.execute(select).fetchall(), result)
@testing.requires.order_by_collation
def test_collate_order_by(self):
collation = testing.requires.get_order_by_collation(testing.config)
self._assert_result(
- select([self.tables.some_table]).
- order_by(self.tables.some_table.c.data.collate(collation).asc()),
- [
- (1, "collate data1"),
- (2, "collate data2"),
- ]
+ select([self.tables.some_table]).order_by(
+ self.tables.some_table.c.data.collate(collation).asc()
+ ),
+ [(1, "collate data1"), (2, "collate data2")],
)
@@ -59,17 +56,20 @@ class OrderByLabelTest(fixtures.TablesTest):
setting.
"""
+
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer),
- Column('q', String(50)),
- Column('p', String(50))
- )
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("q", String(50)),
+ Column("p", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -79,65 +79,55 @@ class OrderByLabelTest(fixtures.TablesTest):
{"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"},
{"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"},
{"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"},
- ]
+ ],
)
def _assert_result(self, select, result):
- eq_(
- config.db.execute(select).fetchall(),
- result
- )
+ eq_(config.db.execute(select).fetchall(), result)
def test_plain(self):
table = self.tables.some_table
- lx = table.c.x.label('lx')
- self._assert_result(
- select([lx]).order_by(lx),
- [(1, ), (2, ), (3, )]
- )
+ lx = table.c.x.label("lx")
+ self._assert_result(select([lx]).order_by(lx), [(1,), (2,), (3,)])
def test_composed_int(self):
table = self.tables.some_table
- lx = (table.c.x + table.c.y).label('lx')
- self._assert_result(
- select([lx]).order_by(lx),
- [(3, ), (5, ), (7, )]
- )
+ lx = (table.c.x + table.c.y).label("lx")
+ self._assert_result(select([lx]).order_by(lx), [(3,), (5,), (7,)])
def test_composed_multiple(self):
table = self.tables.some_table
- lx = (table.c.x + table.c.y).label('lx')
- ly = (func.lower(table.c.q) + table.c.p).label('ly')
+ lx = (table.c.x + table.c.y).label("lx")
+ ly = (func.lower(table.c.q) + table.c.p).label("ly")
self._assert_result(
select([lx, ly]).order_by(lx, ly.desc()),
- [(3, util.u('q1p3')), (5, util.u('q2p2')), (7, util.u('q3p1'))]
+ [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))],
)
def test_plain_desc(self):
table = self.tables.some_table
- lx = table.c.x.label('lx')
+ lx = table.c.x.label("lx")
self._assert_result(
- select([lx]).order_by(lx.desc()),
- [(3, ), (2, ), (1, )]
+ select([lx]).order_by(lx.desc()), [(3,), (2,), (1,)]
)
def test_composed_int_desc(self):
table = self.tables.some_table
- lx = (table.c.x + table.c.y).label('lx')
+ lx = (table.c.x + table.c.y).label("lx")
self._assert_result(
- select([lx]).order_by(lx.desc()),
- [(7, ), (5, ), (3, )]
+ select([lx]).order_by(lx.desc()), [(7,), (5,), (3,)]
)
@testing.requires.group_by_complex_expression
def test_group_by_composed(self):
table = self.tables.some_table
- expr = (table.c.x + table.c.y).label('lx')
- stmt = select([func.count(table.c.id), expr]).group_by(expr).order_by(expr)
- self._assert_result(
- stmt,
- [(1, 3), (1, 5), (1, 7)]
+ expr = (table.c.x + table.c.y).label("lx")
+ stmt = (
+ select([func.count(table.c.id), expr])
+ .group_by(expr)
+ .order_by(expr)
)
+ self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)])
class LimitOffsetTest(fixtures.TablesTest):
@@ -145,10 +135,13 @@ class LimitOffsetTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
@classmethod
def insert_data(cls):
@@ -159,20 +152,17 @@ class LimitOffsetTest(fixtures.TablesTest):
{"id": 2, "x": 2, "y": 3},
{"id": 3, "x": 3, "y": 4},
{"id": 4, "x": 4, "y": 5},
- ]
+ ],
)
def _assert_result(self, select, result, params=()):
- eq_(
- config.db.execute(select, params).fetchall(),
- result
- )
+ eq_(config.db.execute(select, params).fetchall(), result)
def test_simple_limit(self):
table = self.tables.some_table
self._assert_result(
select([table]).order_by(table.c.id).limit(2),
- [(1, 1, 2), (2, 2, 3)]
+ [(1, 1, 2), (2, 2, 3)],
)
@testing.requires.offset
@@ -180,7 +170,7 @@ class LimitOffsetTest(fixtures.TablesTest):
table = self.tables.some_table
self._assert_result(
select([table]).order_by(table.c.id).offset(2),
- [(3, 3, 4), (4, 4, 5)]
+ [(3, 3, 4), (4, 4, 5)],
)
@testing.requires.offset
@@ -188,7 +178,7 @@ class LimitOffsetTest(fixtures.TablesTest):
table = self.tables.some_table
self._assert_result(
select([table]).order_by(table.c.id).limit(2).offset(1),
- [(2, 2, 3), (3, 3, 4)]
+ [(2, 2, 3), (3, 3, 4)],
)
@testing.requires.offset
@@ -198,41 +188,40 @@ class LimitOffsetTest(fixtures.TablesTest):
table = self.tables.some_table
stmt = select([table]).order_by(table.c.id).limit(2).offset(1)
sql = stmt.compile(
- dialect=config.db.dialect,
- compile_kwargs={"literal_binds": True})
+ dialect=config.db.dialect, compile_kwargs={"literal_binds": True}
+ )
sql = str(sql)
- self._assert_result(
- sql,
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(sql, [(2, 2, 3), (3, 3, 4)])
@testing.requires.bound_limit_offset
def test_bound_limit(self):
table = self.tables.some_table
self._assert_result(
- select([table]).order_by(table.c.id).limit(bindparam('l')),
+ select([table]).order_by(table.c.id).limit(bindparam("l")),
[(1, 1, 2), (2, 2, 3)],
- params={"l": 2}
+ params={"l": 2},
)
@testing.requires.bound_limit_offset
def test_bound_offset(self):
table = self.tables.some_table
self._assert_result(
- select([table]).order_by(table.c.id).offset(bindparam('o')),
+ select([table]).order_by(table.c.id).offset(bindparam("o")),
[(3, 3, 4), (4, 4, 5)],
- params={"o": 2}
+ params={"o": 2},
)
@testing.requires.bound_limit_offset
def test_bound_limit_offset(self):
table = self.tables.some_table
self._assert_result(
- select([table]).order_by(table.c.id).
- limit(bindparam("l")).offset(bindparam("o")),
+ select([table])
+ .order_by(table.c.id)
+ .limit(bindparam("l"))
+ .offset(bindparam("o")),
[(2, 2, 3), (3, 3, 4)],
- params={"l": 2, "o": 1}
+ params={"l": 2, "o": 1},
)
@@ -241,10 +230,13 @@ class CompoundSelectTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
@classmethod
def insert_data(cls):
@@ -255,14 +247,11 @@ class CompoundSelectTest(fixtures.TablesTest):
{"id": 2, "x": 2, "y": 3},
{"id": 3, "x": 3, "y": 4},
{"id": 4, "x": 4, "y": 5},
- ]
+ ],
)
def _assert_result(self, select, result, params=()):
- eq_(
- config.db.execute(select, params).fetchall(),
- result
- )
+ eq_(config.db.execute(select, params).fetchall(), result)
def test_plain_union(self):
table = self.tables.some_table
@@ -270,10 +259,7 @@ class CompoundSelectTest(fixtures.TablesTest):
s2 = select([table]).where(table.c.id == 3)
u1 = union(s1, s2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
def test_select_from_plain_union(self):
table = self.tables.some_table
@@ -281,80 +267,88 @@ class CompoundSelectTest(fixtures.TablesTest):
s2 = select([table]).where(table.c.id == 3)
u1 = union(s1, s2).alias().select()
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
@testing.requires.order_by_col_from_union
@testing.requires.parens_in_union_contained_select_w_limit_offset
def test_limit_offset_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- limit(1).order_by(table.c.id)
- s2 = select([table]).where(table.c.id == 3).\
- limit(1).order_by(table.c.id)
+ s1 = (
+ select([table])
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ )
+ s2 = (
+ select([table])
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ )
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
@testing.requires.parens_in_union_contained_select_wo_limit_offset
def test_order_by_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- order_by(table.c.id)
- s2 = select([table]).where(table.c.id == 3).\
- order_by(table.c.id)
+ s1 = select([table]).where(table.c.id == 2).order_by(table.c.id)
+ s2 = select([table]).where(table.c.id == 3).order_by(table.c.id)
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
def test_distinct_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- distinct()
- s2 = select([table]).where(table.c.id == 3).\
- distinct()
+ s1 = select([table]).where(table.c.id == 2).distinct()
+ s2 = select([table]).where(table.c.id == 3).distinct()
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
@testing.requires.parens_in_union_contained_select_w_limit_offset
def test_limit_offset_in_unions_from_alias(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- limit(1).order_by(table.c.id)
- s2 = select([table]).where(table.c.id == 3).\
- limit(1).order_by(table.c.id)
+ s1 = (
+ select([table])
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ )
+ s2 = (
+ select([table])
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ )
# this necessarily has double parens
u1 = union(s1, s2).alias()
self._assert_result(
- u1.select().limit(2).order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
+ u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
)
def test_limit_offset_aliased_selectable_in_unions(self):
table = self.tables.some_table
- s1 = select([table]).where(table.c.id == 2).\
- limit(1).order_by(table.c.id).alias().select()
- s2 = select([table]).where(table.c.id == 3).\
- limit(1).order_by(table.c.id).alias().select()
+ s1 = (
+ select([table])
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+ s2 = (
+ select([table])
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
u1 = union(s1, s2).limit(2)
- self._assert_result(
- u1.order_by(u1.c.id),
- [(2, 2, 3), (3, 3, 4)]
- )
+ self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
class ExpandingBoundInTest(fixtures.TablesTest):
@@ -362,11 +356,14 @@ class ExpandingBoundInTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('x', Integer),
- Column('y', Integer),
- Column('z', String(50)))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -377,178 +374,184 @@ class ExpandingBoundInTest(fixtures.TablesTest):
{"id": 2, "x": 2, "y": 3, "z": "z2"},
{"id": 3, "x": 3, "y": 4, "z": "z3"},
{"id": 4, "x": 4, "y": 5, "z": "z4"},
- ]
+ ],
)
def _assert_result(self, select, result, params=()):
- eq_(
- config.db.execute(select, params).fetchall(),
- result
- )
+ eq_(config.db.execute(select, params).fetchall(), result)
def test_multiple_empty_sets(self):
# test that any anonymous aliasing used by the dialect
# is fine with duplicates
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.in_(bindparam('q', expanding=True))).where(
- table.c.y.in_(bindparam('p', expanding=True))
- ).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": [], "p": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.in_(bindparam("q", expanding=True)))
+ .where(table.c.y.in_(bindparam("p", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": [], "p": []})
+
@testing.requires.tuple_in
def test_empty_heterogeneous_tuples(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.z).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
@testing.requires.tuple_in
def test_empty_homogeneous_tuples(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.y).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.y).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
def test_bound_in_scalar(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [(2, ), (3, ), (4, )],
- params={"q": [2, 3, 4]},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.in_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]})
+
@testing.requires.tuple_in
def test_bound_in_two_tuple(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.y).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.y).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
+ )
self._assert_result(
- stmt,
- [(2, ), (3, ), (4, )],
- params={"q": [(2, 3), (3, 4), (4, 5)]},
+ stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]}
)
@testing.requires.tuple_in
def test_bound_in_heterogeneous_two_tuple(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- tuple_(table.c.x, table.c.z).in_(
- bindparam('q', expanding=True))).order_by(table.c.id)
+ stmt = (
+ select([table.c.id])
+ .where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam("q", expanding=True)
+ )
+ )
+ .order_by(table.c.id)
+ )
self._assert_result(
stmt,
- [(2, ), (3, ), (4, )],
+ [(2,), (3,), (4,)],
params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
)
def test_empty_set_against_integer(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.in_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
def test_empty_set_against_integer_negation(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.x.notin_(bindparam('q', expanding=True))
- ).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [(1, ), (2, ), (3, ), (4, )],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.x.notin_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
def test_empty_set_against_string(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.z.in_(bindparam('q', expanding=True))).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.z.in_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [], params={"q": []})
+
def test_empty_set_against_string_negation(self):
table = self.tables.some_table
- stmt = select([table.c.id]).where(
- table.c.z.notin_(bindparam('q', expanding=True))
- ).order_by(table.c.id)
-
- self._assert_result(
- stmt,
- [(1, ), (2, ), (3, ), (4, )],
- params={"q": []},
+ stmt = (
+ select([table.c.id])
+ .where(table.c.z.notin_(bindparam("q", expanding=True)))
+ .order_by(table.c.id)
)
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
def test_null_in_empty_set_is_false(self):
- stmt = select([
- case(
- [
- (
- null().in_(bindparam('foo', value=(), expanding=True)),
- true()
- )
- ],
- else_=false()
- )
- ])
- in_(
- config.db.execute(stmt).fetchone()[0],
- (False, 0)
+ stmt = select(
+ [
+ case(
+ [
+ (
+ null().in_(
+ bindparam("foo", value=(), expanding=True)
+ ),
+ true(),
+ )
+ ],
+ else_=false(),
+ )
+ ]
)
+ in_(config.db.execute(stmt).fetchone()[0], (False, 0))
class LikeFunctionsTest(fixtures.TablesTest):
__backend__ = True
- run_inserts = 'once'
+ run_inserts = "once"
run_deletes = None
@classmethod
def define_tables(cls, metadata):
- Table("some_table", metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50)))
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -565,7 +568,7 @@ class LikeFunctionsTest(fixtures.TablesTest):
{"id": 8, "data": "ab9cdefg"},
{"id": 9, "data": "abcde#fg"},
{"id": 10, "data": "abcd9fg"},
- ]
+ ],
)
def _test(self, expr, expected):
@@ -573,8 +576,10 @@ class LikeFunctionsTest(fixtures.TablesTest):
with config.db.connect() as conn:
rows = {
- value for value, in
- conn.execute(select([some_table.c.id]).where(expr))
+ value
+ for value, in conn.execute(
+ select([some_table.c.id]).where(expr)
+ )
}
eq_(rows, expected)
@@ -591,7 +596,8 @@ class LikeFunctionsTest(fixtures.TablesTest):
col = self.tables.some_table.c.data
self._test(
col.startswith(literal_column("'ab%c'")),
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
+ )
def test_startswith_escape(self):
col = self.tables.some_table.c.data
@@ -608,8 +614,9 @@ class LikeFunctionsTest(fixtures.TablesTest):
def test_endswith_sqlexpr(self):
col = self.tables.some_table.c.data
- self._test(col.endswith(literal_column("'e%fg'")),
- {1, 2, 3, 4, 5, 6, 7, 8, 9})
+ self._test(
+ col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9}
+ )
def test_endswith_autoescape(self):
col = self.tables.some_table.c.data
@@ -640,5 +647,3 @@ class LikeFunctionsTest(fixtures.TablesTest):
col = self.tables.some_table.c.data
self._test(col.contains("b%cd", autoescape=True, escape="#"), {3})
self._test(col.contains("b#cd", autoescape=True, escape="#"), {7})
-
-
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
index f1c00de6b..15a850fe9 100644
--- a/lib/sqlalchemy/testing/suite/test_sequence.py
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -9,140 +9,144 @@ from ..schema import Table, Column
class SequenceTest(fixtures.TablesTest):
- __requires__ = ('sequences',)
+ __requires__ = ("sequences",)
__backend__ = True
- run_create_tables = 'each'
+ run_create_tables = "each"
@classmethod
def define_tables(cls, metadata):
- Table('seq_pk', metadata,
- Column('id', Integer, Sequence('tab_id_seq'), primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "seq_pk",
+ metadata,
+ Column("id", Integer, Sequence("tab_id_seq"), primary_key=True),
+ Column("data", String(50)),
+ )
- Table('seq_opt_pk', metadata,
- Column('id', Integer, Sequence('tab_id_seq', optional=True),
- primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "seq_opt_pk",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("tab_id_seq", optional=True),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ )
def test_insert_roundtrip(self):
- config.db.execute(
- self.tables.seq_pk.insert(),
- data="some data"
- )
+ config.db.execute(self.tables.seq_pk.insert(), data="some data")
self._assert_round_trip(self.tables.seq_pk, config.db)
def test_insert_lastrowid(self):
- r = config.db.execute(
- self.tables.seq_pk.insert(),
- data="some data"
- )
- eq_(
- r.inserted_primary_key,
- [1]
- )
+ r = config.db.execute(self.tables.seq_pk.insert(), data="some data")
+ eq_(r.inserted_primary_key, [1])
def test_nextval_direct(self):
- r = config.db.execute(
- self.tables.seq_pk.c.id.default
- )
- eq_(
- r, 1
- )
+ r = config.db.execute(self.tables.seq_pk.c.id.default)
+ eq_(r, 1)
@requirements.sequences_optional
def test_optional_seq(self):
r = config.db.execute(
- self.tables.seq_opt_pk.insert(),
- data="some data"
- )
- eq_(
- r.inserted_primary_key,
- [1]
+ self.tables.seq_opt_pk.insert(), data="some data"
)
+ eq_(r.inserted_primary_key, [1])
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
- eq_(
- row,
- (1, "some data")
- )
+ eq_(row, (1, "some data"))
class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
- __requires__ = ('sequences',)
+ __requires__ = ("sequences",)
__backend__ = True
def test_literal_binds_inline_compile(self):
table = Table(
- 'x', MetaData(),
- Column('y', Integer, Sequence('y_seq')),
- Column('q', Integer))
+ "x",
+ MetaData(),
+ Column("y", Integer, Sequence("y_seq")),
+ Column("q", Integer),
+ )
stmt = table.insert().values(q=5)
seq_nextval = testing.db.dialect.statement_compiler(
- statement=None, dialect=testing.db.dialect).visit_sequence(
- Sequence("y_seq"))
+ statement=None, dialect=testing.db.dialect
+ ).visit_sequence(Sequence("y_seq"))
self.assert_compile(
stmt,
- "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval, ),
+ "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,),
literal_binds=True,
- dialect=testing.db.dialect)
+ dialect=testing.db.dialect,
+ )
class HasSequenceTest(fixtures.TestBase):
- __requires__ = 'sequences',
+ __requires__ = ("sequences",)
__backend__ = True
def test_has_sequence(self):
- s1 = Sequence('user_id_seq')
+ s1 = Sequence("user_id_seq")
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(testing.db,
- 'user_id_seq'), True)
+ eq_(
+ testing.db.dialect.has_sequence(testing.db, "user_id_seq"),
+ True,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
@testing.requires.schemas
def test_has_sequence_schema(self):
- s1 = Sequence('user_id_seq', schema=config.test_schema)
+ s1 = Sequence("user_id_seq", schema=config.test_schema)
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(
- testing.db, 'user_id_seq', schema=config.test_schema), True)
+ eq_(
+ testing.db.dialect.has_sequence(
+ testing.db, "user_id_seq", schema=config.test_schema
+ ),
+ True,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
def test_has_sequence_neg(self):
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),
- False)
+ eq_(testing.db.dialect.has_sequence(testing.db, "user_id_seq"), False)
@testing.requires.schemas
def test_has_sequence_schemas_neg(self):
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
- schema=config.test_schema),
- False)
+ eq_(
+ testing.db.dialect.has_sequence(
+ testing.db, "user_id_seq", schema=config.test_schema
+ ),
+ False,
+ )
@testing.requires.schemas
def test_has_sequence_default_not_in_remote(self):
- s1 = Sequence('user_id_seq')
+ s1 = Sequence("user_id_seq")
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
- schema=config.test_schema),
- False)
+ eq_(
+ testing.db.dialect.has_sequence(
+ testing.db, "user_id_seq", schema=config.test_schema
+ ),
+ False,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
@testing.requires.schemas
def test_has_sequence_remote_not_in_default(self):
- s1 = Sequence('user_id_seq', schema=config.test_schema)
+ s1 = Sequence("user_id_seq", schema=config.test_schema)
testing.db.execute(schema.CreateSequence(s1))
try:
- eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),
- False)
+ eq_(
+ testing.db.dialect.has_sequence(testing.db, "user_id_seq"),
+ False,
+ )
finally:
testing.db.execute(schema.DropSequence(s1))
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 27c7bb115..6dfb80915 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -4,9 +4,24 @@ from .. import fixtures, config
from ..assertions import eq_
from ..config import requirements
from sqlalchemy import Integer, Unicode, UnicodeText, select, TIMESTAMP
-from sqlalchemy import Date, DateTime, Time, MetaData, String, \
- Text, Numeric, Float, literal, Boolean, cast, null, JSON, and_, \
- type_coerce, BigInteger
+from sqlalchemy import (
+ Date,
+ DateTime,
+ Time,
+ MetaData,
+ String,
+ Text,
+ Numeric,
+ Float,
+ literal,
+ Boolean,
+ cast,
+ null,
+ JSON,
+ and_,
+ type_coerce,
+ BigInteger,
+)
from ..schema import Table, Column
from ... import testing
import decimal
@@ -24,13 +39,17 @@ class _LiteralRoundTripFixture(object):
# into a typed column. we can then SELECT it back as its
# official type; ideally we'd be able to use CAST here
# but MySQL in particular can't CAST fully
- t = Table('t', self.metadata, Column('x', type_))
+ t = Table("t", self.metadata, Column("x", type_))
t.create()
for value in input_:
- ins = t.insert().values(x=literal(value)).compile(
- dialect=testing.db.dialect,
- compile_kwargs=dict(literal_binds=True)
+ ins = (
+ t.insert()
+ .values(x=literal(value))
+ .compile(
+ dialect=testing.db.dialect,
+ compile_kwargs=dict(literal_binds=True),
+ )
)
testing.db.execute(ins)
@@ -42,40 +61,33 @@ class _LiteralRoundTripFixture(object):
class _UnicodeFixture(_LiteralRoundTripFixture):
- __requires__ = 'unicode_data',
+ __requires__ = ("unicode_data",)
- data = u("Alors vous imaginez ma surprise, au lever du jour, "
- "quand une drôle de petite voix m’a réveillé. Elle "
- "disait: « S’il vous plaît… dessine-moi un mouton! »")
+ data = u(
+ "Alors vous imaginez ma surprise, au lever du jour, "
+ "quand une drôle de petite voix m’a réveillé. Elle "
+ "disait: « S’il vous plaît… dessine-moi un mouton! »"
+ )
@classmethod
def define_tables(cls, metadata):
- Table('unicode_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('unicode_data', cls.datatype),
- )
+ Table(
+ "unicode_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("unicode_data", cls.datatype),
+ )
def test_round_trip(self):
unicode_table = self.tables.unicode_table
- config.db.execute(
- unicode_table.insert(),
- {
- 'unicode_data': self.data,
- }
- )
+ config.db.execute(unicode_table.insert(), {"unicode_data": self.data})
- row = config.db.execute(
- select([
- unicode_table.c.unicode_data,
- ])
- ).first()
+ row = config.db.execute(select([unicode_table.c.unicode_data])).first()
- eq_(
- row,
- (self.data, )
- )
+ eq_(row, (self.data,))
assert isinstance(row[0], util.text_type)
def test_round_trip_executemany(self):
@@ -83,44 +95,29 @@ class _UnicodeFixture(_LiteralRoundTripFixture):
config.db.execute(
unicode_table.insert(),
- [
- {
- 'unicode_data': self.data,
- }
- for i in range(3)
- ]
+ [{"unicode_data": self.data} for i in range(3)],
)
rows = config.db.execute(
- select([
- unicode_table.c.unicode_data,
- ])
+ select([unicode_table.c.unicode_data])
).fetchall()
- eq_(
- rows,
- [(self.data, ) for i in range(3)]
- )
+ eq_(rows, [(self.data,) for i in range(3)])
for row in rows:
assert isinstance(row[0], util.text_type)
def _test_empty_strings(self):
unicode_table = self.tables.unicode_table
- config.db.execute(
- unicode_table.insert(),
- {"unicode_data": u('')}
- )
- row = config.db.execute(
- select([unicode_table.c.unicode_data])
- ).first()
- eq_(row, (u(''),))
+ config.db.execute(unicode_table.insert(), {"unicode_data": u("")})
+ row = config.db.execute(select([unicode_table.c.unicode_data])).first()
+ eq_(row, (u(""),))
def test_literal(self):
self._literal_round_trip(self.datatype, [self.data], [self.data])
class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
- __requires__ = 'unicode_data',
+ __requires__ = ("unicode_data",)
__backend__ = True
datatype = Unicode(255)
@@ -131,7 +128,7 @@ class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
- __requires__ = 'unicode_data', 'text_type'
+ __requires__ = "unicode_data", "text_type"
__backend__ = True
datatype = UnicodeText()
@@ -142,54 +139,47 @@ class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest):
- __requires__ = 'text_type',
+ __requires__ = ("text_type",)
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('text_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('text_data', Text),
- )
+ Table(
+ "text_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("text_data", Text),
+ )
def test_text_roundtrip(self):
text_table = self.tables.text_table
- config.db.execute(
- text_table.insert(),
- {"text_data": 'some text'}
- )
- row = config.db.execute(
- select([text_table.c.text_data])
- ).first()
- eq_(row, ('some text',))
+ config.db.execute(text_table.insert(), {"text_data": "some text"})
+ row = config.db.execute(select([text_table.c.text_data])).first()
+ eq_(row, ("some text",))
def test_text_empty_strings(self):
text_table = self.tables.text_table
- config.db.execute(
- text_table.insert(),
- {"text_data": ''}
- )
- row = config.db.execute(
- select([text_table.c.text_data])
- ).first()
- eq_(row, ('',))
+ config.db.execute(text_table.insert(), {"text_data": ""})
+ row = config.db.execute(select([text_table.c.text_data])).first()
+ eq_(row, ("",))
def test_literal(self):
self._literal_round_trip(Text, ["some text"], ["some text"])
def test_literal_quoting(self):
- data = '''some 'text' hey "hi there" that's text'''
+ data = """some 'text' hey "hi there" that's text"""
self._literal_round_trip(Text, [data], [data])
def test_literal_backslashes(self):
- data = r'backslash one \ backslash two \\ end'
+ data = r"backslash one \ backslash two \\ end"
self._literal_round_trip(Text, [data], [data])
def test_literal_percentsigns(self):
- data = r'percent % signs %% percent'
+ data = r"percent % signs %% percent"
self._literal_round_trip(Text, [data], [data])
@@ -199,9 +189,7 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
@requirements.unbounded_varchar
def test_nolength_string(self):
metadata = MetaData()
- foo = Table('foo', metadata,
- Column('one', String)
- )
+ foo = Table("foo", metadata, Column("one", String))
foo.create(config.db)
foo.drop(config.db)
@@ -210,11 +198,11 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
self._literal_round_trip(String(40), ["some text"], ["some text"])
def test_literal_quoting(self):
- data = '''some 'text' hey "hi there" that's text'''
+ data = """some 'text' hey "hi there" that's text"""
self._literal_round_trip(String(40), [data], [data])
def test_literal_backslashes(self):
- data = r'backslash one \ backslash two \\ end'
+ data = r"backslash one \ backslash two \\ end"
self._literal_round_trip(String(40), [data], [data])
@@ -223,44 +211,32 @@ class _DateFixture(_LiteralRoundTripFixture):
@classmethod
def define_tables(cls, metadata):
- Table('date_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('date_data', cls.datatype),
- )
+ Table(
+ "date_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("date_data", cls.datatype),
+ )
def test_round_trip(self):
date_table = self.tables.date_table
- config.db.execute(
- date_table.insert(),
- {'date_data': self.data}
- )
+ config.db.execute(date_table.insert(), {"date_data": self.data})
- row = config.db.execute(
- select([
- date_table.c.date_data,
- ])
- ).first()
+ row = config.db.execute(select([date_table.c.date_data])).first()
compare = self.compare or self.data
- eq_(row,
- (compare, ))
+ eq_(row, (compare,))
assert isinstance(row[0], type(compare))
def test_null(self):
date_table = self.tables.date_table
- config.db.execute(
- date_table.insert(),
- {'date_data': None}
- )
+ config.db.execute(date_table.insert(), {"date_data": None})
- row = config.db.execute(
- select([
- date_table.c.date_data,
- ])
- ).first()
+ row = config.db.execute(select([date_table.c.date_data])).first()
eq_(row, (None,))
@testing.requires.datetime_literals
@@ -270,48 +246,49 @@ class _DateFixture(_LiteralRoundTripFixture):
class DateTimeTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'datetime',
+ __requires__ = ("datetime",)
__backend__ = True
datatype = DateTime
data = datetime.datetime(2012, 10, 15, 12, 57, 18)
class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'datetime_microseconds',
+ __requires__ = ("datetime_microseconds",)
__backend__ = True
datatype = DateTime
data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+
class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'timestamp_microseconds',
+ __requires__ = ("timestamp_microseconds",)
__backend__ = True
datatype = TIMESTAMP
data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
class TimeTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'time',
+ __requires__ = ("time",)
__backend__ = True
datatype = Time
data = datetime.time(12, 57, 18)
class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'time_microseconds',
+ __requires__ = ("time_microseconds",)
__backend__ = True
datatype = Time
data = datetime.time(12, 57, 18, 396)
class DateTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'date',
+ __requires__ = ("date",)
__backend__ = True
datatype = Date
data = datetime.date(2012, 10, 15)
class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'date', 'date_coerces_from_datetime'
+ __requires__ = "date", "date_coerces_from_datetime"
__backend__ = True
datatype = Date
data = datetime.datetime(2012, 10, 15, 12, 57, 18)
@@ -319,14 +296,14 @@ class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'datetime_historic',
+ __requires__ = ("datetime_historic",)
__backend__ = True
datatype = DateTime
data = datetime.datetime(1850, 11, 10, 11, 52, 35)
class DateHistoricTest(_DateFixture, fixtures.TablesTest):
- __requires__ = 'date_historic',
+ __requires__ = ("date_historic",)
__backend__ = True
datatype = Date
data = datetime.date(1727, 4, 1)
@@ -345,26 +322,21 @@ class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase):
def _round_trip(self, datatype, data):
metadata = self.metadata
int_table = Table(
- 'integer_table', metadata,
- Column('id', Integer, primary_key=True,
- test_needs_autoincrement=True),
- Column('integer_data', datatype),
+ "integer_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("integer_data", datatype),
)
metadata.create_all(config.db)
- config.db.execute(
- int_table.insert(),
- {'integer_data': data}
- )
+ config.db.execute(int_table.insert(), {"integer_data": data})
- row = config.db.execute(
- select([
- int_table.c.integer_data,
- ])
- ).first()
+ row = config.db.execute(select([int_table.c.integer_data])).first()
- eq_(row, (data, ))
+ eq_(row, (data,))
if util.py3k:
assert isinstance(row[0], int)
@@ -377,12 +349,11 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
@testing.emits_warning(r".*does \*not\* support Decimal objects natively")
@testing.provide_metadata
- def _do_test(self, type_, input_, output,
- filter_=None, check_scale=False):
+ def _do_test(self, type_, input_, output, filter_=None, check_scale=False):
metadata = self.metadata
- t = Table('t', metadata, Column('x', type_))
+ t = Table("t", metadata, Column("x", type_))
t.create()
- t.insert().execute([{'x': x} for x in input_])
+ t.insert().execute([{"x": x} for x in input_])
result = {row[0] for row in t.select().execute()}
output = set(output)
@@ -391,10 +362,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
output = set(filter_(x) for x in output)
eq_(result, output)
if check_scale:
- eq_(
- [str(x) for x in result],
- [str(x) for x in output],
- )
+ eq_([str(x) for x in result], [str(x) for x in output])
@testing.emits_warning(r".*does \*not\* support Decimal objects natively")
def test_render_literal_numeric(self):
@@ -416,8 +384,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
self._literal_round_trip(
Float(4),
[15.7563, decimal.Decimal("15.7563")],
- [15.7563, ],
- filter_=lambda n: n is not None and round(n, 5) or None
+ [15.7563],
+ filter_=lambda n: n is not None and round(n, 5) or None,
)
@testing.requires.precision_generic_float_type
@@ -425,8 +393,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
self._do_test(
Float(None, decimal_return_scale=7, asdecimal=True),
[15.7563827, decimal.Decimal("15.7563827")],
- [decimal.Decimal("15.7563827"), ],
- check_scale=True
+ [decimal.Decimal("15.7563827")],
+ check_scale=True,
)
def test_numeric_as_decimal(self):
@@ -445,18 +413,12 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
@testing.requires.fetch_null_from_numeric
def test_numeric_null_as_decimal(self):
- self._do_test(
- Numeric(precision=8, scale=4),
- [None],
- [None],
- )
+ self._do_test(Numeric(precision=8, scale=4), [None], [None])
@testing.requires.fetch_null_from_numeric
def test_numeric_null_as_float(self):
self._do_test(
- Numeric(precision=8, scale=4, asdecimal=False),
- [None],
- [None],
+ Numeric(precision=8, scale=4, asdecimal=False), [None], [None]
)
@testing.requires.floats_to_four_decimals
@@ -472,15 +434,13 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
Float(precision=8),
[15.7563, decimal.Decimal("15.7563")],
[15.7563],
- filter_=lambda n: n is not None and round(n, 5) or None
+ filter_=lambda n: n is not None and round(n, 5) or None,
)
def test_float_coerce_round_trip(self):
expr = 15.7563
- val = testing.db.scalar(
- select([literal(expr)])
- )
+ val = testing.db.scalar(select([literal(expr)]))
eq_(val, expr)
# this does not work in MySQL, see #4036, however we choose not
@@ -491,34 +451,28 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
def test_decimal_coerce_round_trip(self):
expr = decimal.Decimal("15.7563")
- val = testing.db.scalar(
- select([literal(expr)])
- )
+ val = testing.db.scalar(select([literal(expr)]))
eq_(val, expr)
@testing.emits_warning(r".*does \*not\* support Decimal objects natively")
def test_decimal_coerce_round_trip_w_cast(self):
expr = decimal.Decimal("15.7563")
- val = testing.db.scalar(
- select([cast(expr, Numeric(10, 4))])
- )
+ val = testing.db.scalar(select([cast(expr, Numeric(10, 4))]))
eq_(val, expr)
@testing.requires.precision_numerics_general
def test_precision_decimal(self):
- numbers = set([
- decimal.Decimal("54.234246451650"),
- decimal.Decimal("0.004354"),
- decimal.Decimal("900.0"),
- ])
-
- self._do_test(
- Numeric(precision=18, scale=12),
- numbers,
- numbers,
+ numbers = set(
+ [
+ decimal.Decimal("54.234246451650"),
+ decimal.Decimal("0.004354"),
+ decimal.Decimal("900.0"),
+ ]
)
+ self._do_test(Numeric(precision=18, scale=12), numbers, numbers)
+
@testing.requires.precision_numerics_enotation_large
def test_enotation_decimal(self):
"""test exceedingly small decimals.
@@ -528,25 +482,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
"""
- numbers = set([
- decimal.Decimal('1E-2'),
- decimal.Decimal('1E-3'),
- decimal.Decimal('1E-4'),
- decimal.Decimal('1E-5'),
- decimal.Decimal('1E-6'),
- decimal.Decimal('1E-7'),
- decimal.Decimal('1E-8'),
- decimal.Decimal("0.01000005940696"),
- decimal.Decimal("0.00000005940696"),
- decimal.Decimal("0.00000000000696"),
- decimal.Decimal("0.70000000000696"),
- decimal.Decimal("696E-12"),
- ])
- self._do_test(
- Numeric(precision=18, scale=14),
- numbers,
- numbers
+ numbers = set(
+ [
+ decimal.Decimal("1E-2"),
+ decimal.Decimal("1E-3"),
+ decimal.Decimal("1E-4"),
+ decimal.Decimal("1E-5"),
+ decimal.Decimal("1E-6"),
+ decimal.Decimal("1E-7"),
+ decimal.Decimal("1E-8"),
+ decimal.Decimal("0.01000005940696"),
+ decimal.Decimal("0.00000005940696"),
+ decimal.Decimal("0.00000000000696"),
+ decimal.Decimal("0.70000000000696"),
+ decimal.Decimal("696E-12"),
+ ]
)
+ self._do_test(Numeric(precision=18, scale=14), numbers, numbers)
@testing.requires.precision_numerics_enotation_large
def test_enotation_decimal_large(self):
@@ -554,41 +506,32 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
"""
- numbers = set([
- decimal.Decimal('4E+8'),
- decimal.Decimal("5748E+15"),
- decimal.Decimal('1.521E+15'),
- decimal.Decimal('00000000000000.1E+12'),
- ])
- self._do_test(
- Numeric(precision=25, scale=2),
- numbers,
- numbers
+ numbers = set(
+ [
+ decimal.Decimal("4E+8"),
+ decimal.Decimal("5748E+15"),
+ decimal.Decimal("1.521E+15"),
+ decimal.Decimal("00000000000000.1E+12"),
+ ]
)
+ self._do_test(Numeric(precision=25, scale=2), numbers, numbers)
@testing.requires.precision_numerics_many_significant_digits
def test_many_significant_digits(self):
- numbers = set([
- decimal.Decimal("31943874831932418390.01"),
- decimal.Decimal("319438950232418390.273596"),
- decimal.Decimal("87673.594069654243"),
- ])
- self._do_test(
- Numeric(precision=38, scale=12),
- numbers,
- numbers
+ numbers = set(
+ [
+ decimal.Decimal("31943874831932418390.01"),
+ decimal.Decimal("319438950232418390.273596"),
+ decimal.Decimal("87673.594069654243"),
+ ]
)
+ self._do_test(Numeric(precision=38, scale=12), numbers, numbers)
@testing.requires.precision_numerics_retains_significant_digits
def test_numeric_no_decimal(self):
- numbers = set([
- decimal.Decimal("1.000")
- ])
+ numbers = set([decimal.Decimal("1.000")])
self._do_test(
- Numeric(precision=5, scale=3),
- numbers,
- numbers,
- check_scale=True
+ Numeric(precision=5, scale=3), numbers, numbers, check_scale=True
)
@@ -597,42 +540,32 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
- Table('boolean_table', metadata,
- Column('id', Integer, primary_key=True, autoincrement=False),
- Column('value', Boolean),
- Column('unconstrained_value', Boolean(create_constraint=False)),
- )
+ Table(
+ "boolean_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("value", Boolean),
+ Column("unconstrained_value", Boolean(create_constraint=False)),
+ )
def test_render_literal_bool(self):
- self._literal_round_trip(
- Boolean(),
- [True, False],
- [True, False]
- )
+ self._literal_round_trip(Boolean(), [True, False], [True, False])
def test_round_trip(self):
boolean_table = self.tables.boolean_table
config.db.execute(
boolean_table.insert(),
- {
- 'id': 1,
- 'value': True,
- 'unconstrained_value': False
- }
+ {"id": 1, "value": True, "unconstrained_value": False},
)
row = config.db.execute(
- select([
- boolean_table.c.value,
- boolean_table.c.unconstrained_value
- ])
+ select(
+ [boolean_table.c.value, boolean_table.c.unconstrained_value]
+ )
).first()
- eq_(
- row,
- (True, False)
- )
+ eq_(row, (True, False))
assert isinstance(row[0], bool)
def test_null(self):
@@ -640,24 +573,16 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
config.db.execute(
boolean_table.insert(),
- {
- 'id': 1,
- 'value': None,
- 'unconstrained_value': None
- }
+ {"id": 1, "value": None, "unconstrained_value": None},
)
row = config.db.execute(
- select([
- boolean_table.c.value,
- boolean_table.c.unconstrained_value
- ])
+ select(
+ [boolean_table.c.value, boolean_table.c.unconstrained_value]
+ )
).first()
- eq_(
- row,
- (None, None)
- )
+ eq_(row, (None, None))
def test_whereclause(self):
# testing "WHERE <column>" renders a compatible expression
@@ -667,92 +592,82 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
conn.execute(
boolean_table.insert(),
[
- {'id': 1, 'value': True, 'unconstrained_value': True},
- {'id': 2, 'value': False, 'unconstrained_value': False}
- ]
+ {"id": 1, "value": True, "unconstrained_value": True},
+ {"id": 2, "value": False, "unconstrained_value": False},
+ ],
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(boolean_table.c.value)
),
- 1
+ 1,
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(
- boolean_table.c.unconstrained_value)
+ boolean_table.c.unconstrained_value
+ )
),
- 1
+ 1,
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(~boolean_table.c.value)
),
- 2
+ 2,
)
eq_(
conn.scalar(
select([boolean_table.c.id]).where(
- ~boolean_table.c.unconstrained_value)
+ ~boolean_table.c.unconstrained_value
+ )
),
- 2
+ 2,
)
-
-
class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
- __requires__ = 'json_type',
+ __requires__ = ("json_type",)
__backend__ = True
datatype = JSON
- data1 = {
- "key1": "value1",
- "key2": "value2"
- }
+ data1 = {"key1": "value1", "key2": "value2"}
data2 = {
"Key 'One'": "value1",
"key two": "value2",
- "key three": "value ' three '"
+ "key three": "value ' three '",
}
data3 = {
"key1": [1, 2, 3],
"key2": ["one", "two", "three"],
- "key3": [{"four": "five"}, {"six": "seven"}]
+ "key3": [{"four": "five"}, {"six": "seven"}],
}
data4 = ["one", "two", "three"]
data5 = {
"nested": {
- "elem1": [
- {"a": "b", "c": "d"},
- {"e": "f", "g": "h"}
- ],
- "elem2": {
- "elem3": {"elem4": "elem5"}
- }
+ "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}],
+ "elem2": {"elem3": {"elem4": "elem5"}},
}
}
- data6 = {
- "a": 5,
- "b": "some value",
- "c": {"foo": "bar"}
- }
+ data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}}
@classmethod
def define_tables(cls, metadata):
- Table('data_table', metadata,
- Column('id', Integer, primary_key=True),
- Column('name', String(30), nullable=False),
- Column('data', cls.datatype),
- Column('nulldata', cls.datatype(none_as_null=True))
- )
+ Table(
+ "data_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(30), nullable=False),
+ Column("data", cls.datatype),
+ Column("nulldata", cls.datatype(none_as_null=True)),
+ )
def test_round_trip_data1(self):
self._test_round_trip(self.data1)
@@ -761,99 +676,82 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
data_table = self.tables.data_table
config.db.execute(
- data_table.insert(),
- {'name': 'row1', 'data': data_element}
+ data_table.insert(), {"name": "row1", "data": data_element}
)
- row = config.db.execute(
- select([
- data_table.c.data,
- ])
- ).first()
+ row = config.db.execute(select([data_table.c.data])).first()
- eq_(row, (data_element, ))
+ eq_(row, (data_element,))
def test_round_trip_none_as_sql_null(self):
- col = self.tables.data_table.c['nulldata']
+ col = self.tables.data_table.c["nulldata"]
with config.db.connect() as conn:
conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": None}
+ self.tables.data_table.insert(), {"name": "r1", "data": None}
)
eq_(
conn.scalar(
- select([self.tables.data_table.c.name]).
- where(col.is_(null()))
+ select([self.tables.data_table.c.name]).where(
+ col.is_(null())
+ )
),
- "r1"
+ "r1",
)
- eq_(
- conn.scalar(
- select([col])
- ),
- None
- )
+ eq_(conn.scalar(select([col])), None)
def test_round_trip_json_null_as_json_null(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
with config.db.connect() as conn:
conn.execute(
self.tables.data_table.insert(),
- {"name": "r1", "data": JSON.NULL}
+ {"name": "r1", "data": JSON.NULL},
)
eq_(
conn.scalar(
- select([self.tables.data_table.c.name]).
- where(cast(col, String) == 'null')
+ select([self.tables.data_table.c.name]).where(
+ cast(col, String) == "null"
+ )
),
- "r1"
+ "r1",
)
- eq_(
- conn.scalar(
- select([col])
- ),
- None
- )
+ eq_(conn.scalar(select([col])), None)
def test_round_trip_none_as_json_null(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
with config.db.connect() as conn:
conn.execute(
- self.tables.data_table.insert(),
- {"name": "r1", "data": None}
+ self.tables.data_table.insert(), {"name": "r1", "data": None}
)
eq_(
conn.scalar(
- select([self.tables.data_table.c.name]).
- where(cast(col, String) == 'null')
+ select([self.tables.data_table.c.name]).where(
+ cast(col, String) == "null"
+ )
),
- "r1"
+ "r1",
)
- eq_(
- conn.scalar(
- select([col])
- ),
- None
- )
+ eq_(conn.scalar(select([col])), None)
def _criteria_fixture(self):
config.db.execute(
self.tables.data_table.insert(),
- [{"name": "r1", "data": self.data1},
- {"name": "r2", "data": self.data2},
- {"name": "r3", "data": self.data3},
- {"name": "r4", "data": self.data4},
- {"name": "r5", "data": self.data5},
- {"name": "r6", "data": self.data6}]
+ [
+ {"name": "r1", "data": self.data1},
+ {"name": "r2", "data": self.data2},
+ {"name": "r3", "data": self.data3},
+ {"name": "r4", "data": self.data4},
+ {"name": "r5", "data": self.data5},
+ {"name": "r6", "data": self.data6},
+ ],
)
def _test_index_criteria(self, crit, expected, test_literal=True):
@@ -861,20 +759,20 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
with config.db.connect() as conn:
stmt = select([self.tables.data_table.c.name]).where(crit)
- eq_(
- conn.scalar(stmt),
- expected
- )
+ eq_(conn.scalar(stmt), expected)
if test_literal:
- literal_sql = str(stmt.compile(
- config.db, compile_kwargs={"literal_binds": True}))
+ literal_sql = str(
+ stmt.compile(
+ config.db, compile_kwargs={"literal_binds": True}
+ )
+ )
eq_(conn.scalar(literal_sql), expected)
def test_crit_spaces_in_key(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
# limit the rows here to avoid PG error
# "cannot extract field from a non-object", which is
@@ -882,76 +780,74 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
self._test_index_criteria(
and_(
name.in_(["r1", "r2", "r3"]),
- cast(col["key two"], String) == '"value2"'
+ cast(col["key two"], String) == '"value2"',
),
- "r2"
+ "r2",
)
@config.requirements.json_array_indexes
def test_crit_simple_int(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
# limit the rows here to avoid PG error
# "cannot extract array element from a non-array", which is
# fixed in 9.4 but may exist in 9.3
self._test_index_criteria(
- and_(name == 'r4', cast(col[1], String) == '"two"'),
- "r4"
+ and_(name == "r4", cast(col[1], String) == '"two"'), "r4"
)
def test_crit_mixed_path(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- cast(col[("key3", 1, "six")], String) == '"seven"',
- "r3"
+ cast(col[("key3", 1, "six")], String) == '"seven"', "r3"
)
def test_crit_string_path(self):
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
cast(col[("nested", "elem2", "elem3", "elem4")], String)
== '"elem5"',
- "r5"
+ "r5",
)
def test_crit_against_string_basic(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6', cast(col["b"], String) == '"some value"'),
- "r6"
+ and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6"
)
def test_crit_against_string_coerce_type(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6',
- cast(col["b"], String) == type_coerce("some value", JSON)),
+ and_(
+ name == "r6",
+ cast(col["b"], String) == type_coerce("some value", JSON),
+ ),
"r6",
- test_literal=False
+ test_literal=False,
)
def test_crit_against_int_basic(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6', cast(col["a"], String) == '5'),
- "r6"
+ and_(name == "r6", cast(col["a"], String) == "5"), "r6"
)
def test_crit_against_int_coerce_type(self):
name = self.tables.data_table.c.name
- col = self.tables.data_table.c['data']
+ col = self.tables.data_table.c["data"]
self._test_index_criteria(
- and_(name == 'r6', cast(col["a"], String) == type_coerce(5, JSON)),
+ and_(name == "r6", cast(col["a"], String) == type_coerce(5, JSON)),
"r6",
- test_literal=False
+ test_literal=False,
)
def test_unicode_round_trip(self):
@@ -961,17 +857,17 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
{
"name": "r1",
"data": {
- util.u('réveillé'): util.u('réveillé'),
- "data": {"k1": util.u('drôle')}
- }
- }
+ util.u("réveillé"): util.u("réveillé"),
+ "data": {"k1": util.u("drôle")},
+ },
+ },
)
eq_(
conn.scalar(select([self.tables.data_table.c.data])),
{
- util.u('réveillé'): util.u('réveillé'),
- "data": {"k1": util.u('drôle')}
+ util.u("réveillé"): util.u("réveillé"),
+ "data": {"k1": util.u("drôle")},
},
)
@@ -986,7 +882,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
s = Session(testing.db)
- d1 = Data(name='d1', data=None, nulldata=None)
+ d1 = Data(name="d1", data=None, nulldata=None)
s.add(d1)
s.commit()
@@ -995,24 +891,46 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
)
eq_(
s.query(
- cast(self.tables.data_table.c.data, String(convert_unicode="force")),
- cast(self.tables.data_table.c.nulldata, String)
- ).filter(self.tables.data_table.c.name == 'd1').first(),
- ("null", None)
+ cast(
+ self.tables.data_table.c.data,
+ String(convert_unicode="force"),
+ ),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
)
eq_(
s.query(
- cast(self.tables.data_table.c.data, String(convert_unicode="force")),
- cast(self.tables.data_table.c.nulldata, String)
- ).filter(self.tables.data_table.c.name == 'd2').first(),
- ("null", None)
- )
-
-
-__all__ = ('UnicodeVarcharTest', 'UnicodeTextTest', 'JSONTest',
- 'DateTest', 'DateTimeTest', 'TextTest',
- 'NumericTest', 'IntegerTest',
- 'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest',
- 'TimeMicrosecondsTest', 'TimestampMicrosecondsTest', 'TimeTest',
- 'DateTimeMicrosecondsTest',
- 'DateHistoricTest', 'StringTest', 'BooleanTest')
+ cast(
+ self.tables.data_table.c.data,
+ String(convert_unicode="force"),
+ ),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
+ )
+
+
+__all__ = (
+ "UnicodeVarcharTest",
+ "UnicodeTextTest",
+ "JSONTest",
+ "DateTest",
+ "DateTimeTest",
+ "TextTest",
+ "NumericTest",
+ "IntegerTest",
+ "DateTimeHistoricTest",
+ "DateTimeCoercedToDateTimeTest",
+ "TimeMicrosecondsTest",
+ "TimestampMicrosecondsTest",
+ "TimeTest",
+ "DateTimeMicrosecondsTest",
+ "DateHistoricTest",
+ "StringTest",
+ "BooleanTest",
+)
diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py
index e4c61e74a..b232c3a78 100644
--- a/lib/sqlalchemy/testing/suite/test_update_delete.py
+++ b/lib/sqlalchemy/testing/suite/test_update_delete.py
@@ -6,15 +6,17 @@ from ..schema import Table, Column
class SimpleUpdateDeleteTest(fixtures.TablesTest):
- run_deletes = 'each'
+ run_deletes = "each"
__backend__ = True
@classmethod
def define_tables(cls, metadata):
- Table('plain_pk', metadata,
- Column('id', Integer, primary_key=True),
- Column('data', String(50))
- )
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
@classmethod
def insert_data(cls):
@@ -24,40 +26,29 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest):
{"id": 1, "data": "d1"},
{"id": 2, "data": "d2"},
{"id": 3, "data": "d3"},
- ]
+ ],
)
def test_update(self):
t = self.tables.plain_pk
- r = config.db.execute(
- t.update().where(t.c.id == 2),
- data="d2_new"
- )
+ r = config.db.execute(t.update().where(t.c.id == 2), data="d2_new")
assert not r.is_insert
assert not r.returns_rows
eq_(
config.db.execute(t.select().order_by(t.c.id)).fetchall(),
- [
- (1, "d1"),
- (2, "d2_new"),
- (3, "d3")
- ]
+ [(1, "d1"), (2, "d2_new"), (3, "d3")],
)
def test_delete(self):
t = self.tables.plain_pk
- r = config.db.execute(
- t.delete().where(t.c.id == 2)
- )
+ r = config.db.execute(t.delete().where(t.c.id == 2))
assert not r.is_insert
assert not r.returns_rows
eq_(
config.db.execute(t.select().order_by(t.c.id)).fetchall(),
- [
- (1, "d1"),
- (3, "d3")
- ]
+ [(1, "d1"), (3, "d3")],
)
-__all__ = ('SimpleUpdateDeleteTest', )
+
+__all__ = ("SimpleUpdateDeleteTest",)
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
index 409d3bda5..5b015d214 100644
--- a/lib/sqlalchemy/testing/util.py
+++ b/lib/sqlalchemy/testing/util.py
@@ -14,6 +14,7 @@ import sys
import types
if jython:
+
def jython_gc_collect(*args):
"""aggressive gc.collect for tests."""
gc.collect()
@@ -25,9 +26,11 @@ if jython:
# "lazy" gc, for VM's that don't GC on refcount == 0
gc_collect = lazy_gc = jython_gc_collect
elif pypy:
+
def pypy_gc_collect(*args):
gc.collect()
gc.collect()
+
gc_collect = lazy_gc = pypy_gc_collect
else:
# assume CPython - straight gc.collect, lazy_gc() is a pass
@@ -42,11 +45,13 @@ def picklers():
if py2k:
try:
import cPickle
+
picklers.add(cPickle)
except ImportError:
pass
import pickle
+
picklers.add(pickle)
# yes, this thing needs this much testing
@@ -60,9 +65,9 @@ def round_decimal(value, prec):
return round(value, prec)
# can also use shift() here but that is 2.6 only
- return (value * decimal.Decimal("1" + "0" * prec)
- ).to_integral(decimal.ROUND_FLOOR) / \
- pow(10, prec)
+ return (value * decimal.Decimal("1" + "0" * prec)).to_integral(
+ decimal.ROUND_FLOOR
+ ) / pow(10, prec)
class RandomSet(set):
@@ -137,8 +142,9 @@ def function_named(fn, name):
try:
fn.__name__ = name
except TypeError:
- fn = types.FunctionType(fn.__code__, fn.__globals__, name,
- fn.__defaults__, fn.__closure__)
+ fn = types.FunctionType(
+ fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__
+ )
return fn
@@ -190,7 +196,7 @@ def provide_metadata(fn, *args, **kw):
metadata = schema.MetaData(config.db)
self = args[0]
- prev_meta = getattr(self, 'metadata', None)
+ prev_meta = getattr(self, "metadata", None)
self.metadata = metadata
try:
return fn(*args, **kw)
@@ -213,8 +219,8 @@ def force_drop_names(*names):
try:
return fn(*args, **kw)
finally:
- drop_all_tables(
- config.db, inspect(config.db), include_names=names)
+ drop_all_tables(config.db, inspect(config.db), include_names=names)
+
return go
@@ -234,8 +240,13 @@ class adict(dict):
def drop_all_tables(engine, inspector, schema=None, include_names=None):
- from sqlalchemy import Column, Table, Integer, MetaData, \
- ForeignKeyConstraint
+ from sqlalchemy import (
+ Column,
+ Table,
+ Integer,
+ MetaData,
+ ForeignKeyConstraint,
+ )
from sqlalchemy.schema import DropTable, DropConstraint
if include_names is not None:
@@ -243,30 +254,35 @@ def drop_all_tables(engine, inspector, schema=None, include_names=None):
with engine.connect() as conn:
for tname, fkcs in reversed(
- inspector.get_sorted_table_and_fkc_names(schema=schema)):
+ inspector.get_sorted_table_and_fkc_names(schema=schema)
+ ):
if tname:
if include_names is not None and tname not in include_names:
continue
- conn.execute(DropTable(
- Table(tname, MetaData(), schema=schema)
- ))
+ conn.execute(
+ DropTable(Table(tname, MetaData(), schema=schema))
+ )
elif fkcs:
if not engine.dialect.supports_alter:
continue
for tname, fkc in fkcs:
- if include_names is not None and \
- tname not in include_names:
+ if (
+ include_names is not None
+ and tname not in include_names
+ ):
continue
tb = Table(
- tname, MetaData(),
- Column('x', Integer),
- Column('y', Integer),
- schema=schema
+ tname,
+ MetaData(),
+ Column("x", Integer),
+ Column("y", Integer),
+ schema=schema,
+ )
+ conn.execute(
+ DropConstraint(
+ ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc)
+ )
)
- conn.execute(DropConstraint(
- ForeignKeyConstraint(
- [tb.c.x], [tb.c.y], name=fkc)
- ))
def teardown_events(event_cls):
@@ -276,5 +292,5 @@ def teardown_events(event_cls):
return fn(*arg, **kw)
finally:
event_cls._clear()
- return decorate
+ return decorate
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
index 46e7c54db..e0101b14d 100644
--- a/lib/sqlalchemy/testing/warnings.py
+++ b/lib/sqlalchemy/testing/warnings.py
@@ -15,17 +15,20 @@ from . import assertions
def setup_filters():
"""Set global warning behavior for the test suite."""
- warnings.filterwarnings('ignore',
- category=sa_exc.SAPendingDeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SAWarning)
+ warnings.filterwarnings(
+ "ignore", category=sa_exc.SAPendingDeprecationWarning
+ )
+ warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings("error", category=sa_exc.SAWarning)
# some selected deprecations...
- warnings.filterwarnings('error', category=DeprecationWarning)
+ warnings.filterwarnings("error", category=DeprecationWarning)
warnings.filterwarnings(
- "ignore", category=DeprecationWarning, message=".*StopIteration")
+ "ignore", category=DeprecationWarning, message=".*StopIteration"
+ )
warnings.filterwarnings(
- "ignore", category=DeprecationWarning, message=".*inspect.getargspec")
+ "ignore", category=DeprecationWarning, message=".*inspect.getargspec"
+ )
def assert_warnings(fn, warning_msgs, regex=False):
@@ -36,6 +39,6 @@ def assert_warnings(fn, warning_msgs, regex=False):
"""
with assertions._expect_warnings(
- sa_exc.SAWarning, warning_msgs, regex=regex):
+ sa_exc.SAWarning, warning_msgs, regex=regex
+ ):
return fn()
-