diff options
Diffstat (limited to 'lib/sqlalchemy/testing')
33 files changed, 2399 insertions, 2114 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 413a492b8..f46ca4528 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -10,23 +10,62 @@ from .warnings import assert_warnings from . import config -from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\ - fails_on, fails_on_everything_except, skip, only_on, exclude, \ - against as _against, _server_version, only_if, fails +from .exclusions import ( + db_spec, + _is_excluded, + fails_if, + skip_if, + future, + fails_on, + fails_on_everything_except, + skip, + only_on, + exclude, + against as _against, + _server_version, + only_if, + fails, +) def against(*queries): return _against(config._current, *queries) -from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ - eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \ - assert_raises_message, AssertsCompiledSQL, ComparesTables, \ - AssertsExecutionResults, expect_deprecated, expect_warnings, \ - in_, not_in_, eq_ignore_whitespace, eq_regex, is_true, is_false -from .util import run_as_contextmanager, rowset, fail, \ - provide_metadata, adict, force_drop_names, \ - teardown_events +from .assertions import ( + emits_warning, + emits_warning_on, + uses_deprecated, + eq_, + ne_, + le_, + is_, + is_not_, + startswith_, + assert_raises, + assert_raises_message, + AssertsCompiledSQL, + ComparesTables, + AssertsExecutionResults, + expect_deprecated, + expect_warnings, + in_, + not_in_, + eq_ignore_whitespace, + eq_regex, + is_true, + is_false, +) + +from .util import ( + run_as_contextmanager, + rowset, + fail, + provide_metadata, + adict, + force_drop_names, + teardown_events, +) crashes = skip diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index e42376921..73ab4556a 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -86,6 +86,7 @@ def emits_warning_on(db, *messages): were in fact seen. """ + @decorator def decorate(fn, *args, **kw): with expect_warnings_on(db, assert_=False, *messages): @@ -114,12 +115,14 @@ def uses_deprecated(*messages): def decorate(fn, *args, **kw): with expect_deprecated(*messages, assert_=False): return fn(*args, **kw) + return decorate @contextlib.contextmanager -def _expect_warnings(exc_cls, messages, regex=True, assert_=True, - py2konly=False): +def _expect_warnings( + exc_cls, messages, regex=True, assert_=True, py2konly=False +): if regex: filters = [re.compile(msg, re.I | re.S) for msg in messages] @@ -145,8 +148,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True, return for filter_ in filters: - if (regex and filter_.match(msg)) or \ - (not regex and filter_ == msg): + if (regex and filter_.match(msg)) or ( + not regex and filter_ == msg + ): seen.discard(filter_) break else: @@ -156,8 +160,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True, yield if assert_ and (not py2konly or not compat.py3k): - assert not seen, "Warnings were not seen: %s" % \ - ", ".join("%r" % (s.pattern if regex else s) for s in seen) + assert not seen, "Warnings were not seen: %s" % ", ".join( + "%r" % (s.pattern if regex else s) for s in seen + ) def global_cleanup_assertions(): @@ -170,6 +175,7 @@ def global_cleanup_assertions(): """ _assert_no_stray_pool_connections() + _STRAY_CONNECTION_FAILURES = 0 @@ -187,8 +193,10 @@ def _assert_no_stray_pool_connections(): # OK, let's be somewhat forgiving. _STRAY_CONNECTION_FAILURES += 1 - print("Encountered a stray connection in test cleanup: %s" - % str(pool._refs)) + print( + "Encountered a stray connection in test cleanup: %s" + % str(pool._refs) + ) # then do a real GC sweep. We shouldn't even be here # so a single sweep should really be doing it, otherwise # there's probably a real unreachable cycle somewhere. @@ -206,8 +214,8 @@ def _assert_no_stray_pool_connections(): pool._refs.clear() _STRAY_CONNECTION_FAILURES = 0 warnings.warn( - "Stray connection refused to leave " - "after gc.collect(): %s" % err) + "Stray connection refused to leave " "after gc.collect(): %s" % err + ) elif _STRAY_CONNECTION_FAILURES > 10: assert False, "Encountered more than 10 stray connections" _STRAY_CONNECTION_FAILURES = 0 @@ -263,14 +271,16 @@ def not_in_(a, b, msg=None): def startswith_(a, fragment, msg=None): """Assert a.startswith(fragment), with repr messaging on failure.""" assert a.startswith(fragment), msg or "%r does not start with %r" % ( - a, fragment) + a, + fragment, + ) def eq_ignore_whitespace(a, b, msg=None): - a = re.sub(r'^\s+?|\n', "", a) - a = re.sub(r' {2,}', " ", a) - b = re.sub(r'^\s+?|\n', "", b) - b = re.sub(r' {2,}', " ", b) + a = re.sub(r"^\s+?|\n", "", a) + a = re.sub(r" {2,}", " ", a) + b = re.sub(r"^\s+?|\n", "", b) + b = re.sub(r" {2,}", " ", b) assert a == b, msg or "%r != %r" % (a, b) @@ -291,32 +301,41 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): callable_(*args, **kwargs) assert False, "Callable did not raise an exception" except except_cls as e: - assert re.search( - msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e) - print(util.text_type(e).encode('utf-8')) + assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % ( + msg, + e, + ) + print(util.text_type(e).encode("utf-8")) + class AssertsCompiledSQL(object): - def assert_compile(self, clause, result, params=None, - checkparams=None, dialect=None, - checkpositional=None, - check_prefetch=None, - use_default_dialect=False, - allow_dialect_select=False, - literal_binds=False, - schema_translate_map=None): + def assert_compile( + self, + clause, + result, + params=None, + checkparams=None, + dialect=None, + checkpositional=None, + check_prefetch=None, + use_default_dialect=False, + allow_dialect_select=False, + literal_binds=False, + schema_translate_map=None, + ): if use_default_dialect: dialect = default.DefaultDialect() elif allow_dialect_select: dialect = None else: if dialect is None: - dialect = getattr(self, '__dialect__', None) + dialect = getattr(self, "__dialect__", None) if dialect is None: dialect = config.db.dialect - elif dialect == 'default': + elif dialect == "default": dialect = default.DefaultDialect() - elif dialect == 'default_enhanced': + elif dialect == "default_enhanced": dialect = default.StrCompileDialect() elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() @@ -325,13 +344,13 @@ class AssertsCompiledSQL(object): compile_kwargs = {} if schema_translate_map: - kw['schema_translate_map'] = schema_translate_map + kw["schema_translate_map"] = schema_translate_map if params is not None: - kw['column_keys'] = list(params) + kw["column_keys"] = list(params) if literal_binds: - compile_kwargs['literal_binds'] = True + compile_kwargs["literal_binds"] = True if isinstance(clause, orm.Query): context = clause._compile_context() @@ -343,25 +362,27 @@ class AssertsCompiledSQL(object): clause = stmt_mock.mock_calls[0][1][0] if compile_kwargs: - kw['compile_kwargs'] = compile_kwargs + kw["compile_kwargs"] = compile_kwargs c = clause.compile(dialect=dialect, **kw) - param_str = repr(getattr(c, 'params', {})) + param_str = repr(getattr(c, "params", {})) if util.py3k: - param_str = param_str.encode('utf-8').decode('ascii', 'ignore') + param_str = param_str.encode("utf-8").decode("ascii", "ignore") print( - ("\nSQL String:\n" + - util.text_type(c) + - param_str).encode('utf-8')) + ("\nSQL String:\n" + util.text_type(c) + param_str).encode( + "utf-8" + ) + ) else: print( - "\nSQL String:\n" + - util.text_type(c).encode('utf-8') + - param_str) + "\nSQL String:\n" + + util.text_type(c).encode("utf-8") + + param_str + ) - cc = re.sub(r'[\n\t]', '', util.text_type(c)) + cc = re.sub(r"[\n\t]", "", util.text_type(c)) eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) @@ -375,7 +396,6 @@ class AssertsCompiledSQL(object): class ComparesTables(object): - def assert_tables_equal(self, table, reflected_table, strict_types=False): assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): @@ -386,8 +406,10 @@ class ComparesTables(object): if strict_types: msg = "Type '%s' doesn't correspond to type '%s'" - assert isinstance(reflected_c.type, type(c.type)), \ - msg % (reflected_c.type, c.type) + assert isinstance(reflected_c.type, type(c.type)), msg % ( + reflected_c.type, + c.type, + ) else: self.assert_types_base(reflected_c, c) @@ -396,20 +418,22 @@ class ComparesTables(object): eq_( {f.column.name for f in c.foreign_keys}, - {f.column.name for f in reflected_c.foreign_keys} + {f.column.name for f in reflected_c.foreign_keys}, ) if c.server_default: - assert isinstance(reflected_c.server_default, - schema.FetchedValue) + assert isinstance( + reflected_c.server_default, schema.FetchedValue + ) assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] is not None def assert_types_base(self, c1, c2): - assert c1.type._compare_type_affinity(c2.type),\ - "On column %r, type '%s' doesn't correspond to type '%s'" % \ - (c1.name, c1.type, c2.type) + assert c1.type._compare_type_affinity(c2.type), ( + "On column %r, type '%s' doesn't correspond to type '%s'" + % (c1.name, c1.type, c2.type) + ) class AssertsExecutionResults(object): @@ -419,15 +443,19 @@ class AssertsExecutionResults(object): self.assert_list(result, class_, objects) def assert_list(self, result, class_, list): - self.assert_(len(result) == len(list), - "result list is not the same size as test list, " + - "for class " + class_.__name__) + self.assert_( + len(result) == len(list), + "result list is not the same size as test list, " + + "for class " + + class_.__name__, + ) for i in range(0, len(list)): self.assert_row(class_, result[i], list[i]) def assert_row(self, class_, rowobj, desc): - self.assert_(rowobj.__class__ is class_, - "item class is not " + repr(class_)) + self.assert_( + rowobj.__class__ is class_, "item class is not " + repr(class_) + ) for key, value in desc.items(): if isinstance(value, tuple): if isinstance(value[1], list): @@ -435,9 +463,11 @@ class AssertsExecutionResults(object): else: self.assert_row(value[0], getattr(rowobj, key), value[1]) else: - self.assert_(getattr(rowobj, key) == value, - "attribute %s value %s does not match %s" % ( - key, getattr(rowobj, key), value)) + self.assert_( + getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" + % (key, getattr(rowobj, key), value), + ) def assert_unordered_result(self, result, cls, *expected): """As assert_result, but the order of objects is not considered. @@ -453,14 +483,19 @@ class AssertsExecutionResults(object): found = util.IdentitySet(result) expected = {immutabledict(e) for e in expected} - for wrong in util.itertools_filterfalse(lambda o: - isinstance(o, cls), found): - fail('Unexpected type "%s", expected "%s"' % ( - type(wrong).__name__, cls.__name__)) + for wrong in util.itertools_filterfalse( + lambda o: isinstance(o, cls), found + ): + fail( + 'Unexpected type "%s", expected "%s"' + % (type(wrong).__name__, cls.__name__) + ) if len(found) != len(expected): - fail('Unexpected object count "%s", expected "%s"' % ( - len(found), len(expected))) + fail( + 'Unexpected object count "%s", expected "%s"' + % (len(found), len(expected)) + ) NOVALUE = object() @@ -469,7 +504,8 @@ class AssertsExecutionResults(object): if isinstance(value, tuple): try: self.assert_unordered_result( - getattr(obj, key), value[0], *value[1]) + getattr(obj, key), value[0], *value[1] + ) except AssertionError: return False else: @@ -484,8 +520,9 @@ class AssertsExecutionResults(object): break else: fail( - "Expected %s instance with attributes %s not found." % ( - cls.__name__, repr(expected_item))) + "Expected %s instance with attributes %s not found." + % (cls.__name__, repr(expected_item)) + ) return True def sql_execution_asserter(self, db=None): @@ -505,9 +542,9 @@ class AssertsExecutionResults(object): newrules = [] for rule in rules: if isinstance(rule, dict): - newrule = assertsql.AllOf(*[ - assertsql.CompiledSQL(k, v) for k, v in rule.items() - ]) + newrule = assertsql.AllOf( + *[assertsql.CompiledSQL(k, v) for k, v in rule.items()] + ) else: newrule = assertsql.CompiledSQL(*rule) newrules.append(newrule) @@ -516,7 +553,8 @@ class AssertsExecutionResults(object): def assert_sql_count(self, db, callable_, count): self.assert_sql_execution( - db, callable_, assertsql.CountStatements(count)) + db, callable_, assertsql.CountStatements(count) + ) def assert_multiple_sql_count(self, dbs, callable_, counts): recs = [ diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 7a525589d..d8e924cb6 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -26,8 +26,10 @@ class AssertRule(object): pass def no_more_statements(self): - assert False, 'All statements are complete, but pending '\ - 'assertion rules remain' + assert False, ( + "All statements are complete, but pending " + "assertion rules remain" + ) class SQLMatchRule(AssertRule): @@ -44,12 +46,17 @@ class CursorSQL(SQLMatchRule): def process_statement(self, execute_observed): stmt = execute_observed.statements[0] if self.statement != stmt.statement or ( - self.params is not None and self.params != stmt.parameters): - self.errormessage = \ - "Testing for exact SQL %s parameters %s received %s %s" % ( - self.statement, self.params, - stmt.statement, stmt.parameters + self.params is not None and self.params != stmt.parameters + ): + self.errormessage = ( + "Testing for exact SQL %s parameters %s received %s %s" + % ( + self.statement, + self.params, + stmt.statement, + stmt.parameters, ) + ) else: execute_observed.statements.pop(0) self.is_consumed = True @@ -58,23 +65,22 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): - - def __init__(self, statement, params=None, dialect='default'): + def __init__(self, statement, params=None, dialect="default"): self.statement = statement self.params = params self.dialect = dialect def _compare_sql(self, execute_observed, received_statement): - stmt = re.sub(r'[\n\t]', '', self.statement) + stmt = re.sub(r"[\n\t]", "", self.statement) return received_statement == stmt def _compile_dialect(self, execute_observed): - if self.dialect == 'default': + if self.dialect == "default": return DefaultDialect() else: # ugh - if self.dialect == 'postgresql': - params = {'implicit_returning': True} + if self.dialect == "postgresql": + params = {"implicit_returning": True} else: params = {} return url.URL(self.dialect).get_dialect()(**params) @@ -86,36 +92,39 @@ class CompiledSQL(SQLMatchRule): context = execute_observed.context compare_dialect = self._compile_dialect(execute_observed) if isinstance(context.compiled.statement, _DDLCompiles): - compiled = \ - context.compiled.statement.compile( - dialect=compare_dialect, - schema_translate_map=context. - execution_options.get('schema_translate_map')) + compiled = context.compiled.statement.compile( + dialect=compare_dialect, + schema_translate_map=context.execution_options.get( + "schema_translate_map" + ), + ) else: - compiled = ( - context.compiled.statement.compile( - dialect=compare_dialect, - column_keys=context.compiled.column_keys, - inline=context.compiled.inline, - schema_translate_map=context. - execution_options.get('schema_translate_map')) + compiled = context.compiled.statement.compile( + dialect=compare_dialect, + column_keys=context.compiled.column_keys, + inline=context.compiled.inline, + schema_translate_map=context.execution_options.get( + "schema_translate_map" + ), ) - _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled)) + _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled)) parameters = execute_observed.parameters if not parameters: _received_parameters = [compiled.construct_params()] else: _received_parameters = [ - compiled.construct_params(m) for m in parameters] + compiled.construct_params(m) for m in parameters + ] return _received_statement, _received_parameters def process_statement(self, execute_observed): context = execute_observed.context - _received_statement, _received_parameters = \ - self._received_statement(execute_observed) + _received_statement, _received_parameters = self._received_statement( + execute_observed + ) params = self._all_params(context) equivalent = self._compare_sql(execute_observed, _received_statement) @@ -132,8 +141,10 @@ class CompiledSQL(SQLMatchRule): for param_key in param: # a key in param did not match current # 'received' - if param_key not in received or \ - received[param_key] != param[param_key]: + if ( + param_key not in received + or received[param_key] != param[param_key] + ): break else: # all keys in param matched 'received'; @@ -153,8 +164,8 @@ class CompiledSQL(SQLMatchRule): self.errormessage = None else: self.errormessage = self._failure_message(params) % { - 'received_statement': _received_statement, - 'received_parameters': _received_parameters + "received_statement": _received_statement, + "received_parameters": _received_parameters, } def _all_params(self, context): @@ -171,11 +182,10 @@ class CompiledSQL(SQLMatchRule): def _failure_message(self, expected_params): return ( - 'Testing for compiled statement %r partial params %r, ' - 'received %%(received_statement)r with params ' - '%%(received_parameters)r' % ( - self.statement.replace('%', '%%'), expected_params - ) + "Testing for compiled statement %r partial params %r, " + "received %%(received_statement)r with params " + "%%(received_parameters)r" + % (self.statement.replace("%", "%%"), expected_params) ) @@ -185,15 +195,13 @@ class RegexSQL(CompiledSQL): self.regex = re.compile(regex) self.orig_regex = regex self.params = params - self.dialect = 'default' + self.dialect = "default" def _failure_message(self, expected_params): return ( - 'Testing for compiled statement ~%r partial params %r, ' - 'received %%(received_statement)r with params ' - '%%(received_parameters)r' % ( - self.orig_regex, expected_params - ) + "Testing for compiled statement ~%r partial params %r, " + "received %%(received_statement)r with params " + "%%(received_parameters)r" % (self.orig_regex, expected_params) ) def _compare_sql(self, execute_observed, received_statement): @@ -205,12 +213,13 @@ class DialectSQL(CompiledSQL): return execute_observed.context.dialect def _compare_no_space(self, real_stmt, received_stmt): - stmt = re.sub(r'[\n\t]', '', real_stmt) + stmt = re.sub(r"[\n\t]", "", real_stmt) return received_stmt == stmt def _received_statement(self, execute_observed): - received_stmt, received_params = super(DialectSQL, self).\ - _received_statement(execute_observed) + received_stmt, received_params = super( + DialectSQL, self + )._received_statement(execute_observed) # TODO: why do we need this part? for real_stmt in execute_observed.statements: @@ -219,34 +228,33 @@ class DialectSQL(CompiledSQL): else: raise AssertionError( "Can't locate compiled statement %r in list of " - "statements actually invoked" % received_stmt) + "statements actually invoked" % received_stmt + ) return received_stmt, execute_observed.context.compiled_parameters def _compare_sql(self, execute_observed, received_statement): - stmt = re.sub(r'[\n\t]', '', self.statement) + stmt = re.sub(r"[\n\t]", "", self.statement) # convert our comparison statement to have the # paramstyle of the received paramstyle = execute_observed.context.dialect.paramstyle - if paramstyle == 'pyformat': - stmt = re.sub( - r':([\w_]+)', r"%(\1)s", stmt) + if paramstyle == "pyformat": + stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt) else: # positional params repl = None - if paramstyle == 'qmark': + if paramstyle == "qmark": repl = "?" - elif paramstyle == 'format': + elif paramstyle == "format": repl = r"%s" - elif paramstyle == 'numeric': + elif paramstyle == "numeric": repl = None - stmt = re.sub(r':([\w_]+)', repl, stmt) + stmt = re.sub(r":([\w_]+)", repl, stmt) return received_statement == stmt class CountStatements(AssertRule): - def __init__(self, count): self.count = count self._statement_count = 0 @@ -256,12 +264,13 @@ class CountStatements(AssertRule): def no_more_statements(self): if self.count != self._statement_count: - assert False, 'desired statement count %d does not match %d' \ - % (self.count, self._statement_count) + assert False, "desired statement count %d does not match %d" % ( + self.count, + self._statement_count, + ) class AllOf(AssertRule): - def __init__(self, *rules): self.rules = set(rules) @@ -283,7 +292,6 @@ class AllOf(AssertRule): class EachOf(AssertRule): - def __init__(self, *rules): self.rules = list(rules) @@ -309,7 +317,6 @@ class EachOf(AssertRule): class Or(AllOf): - def process_statement(self, execute_observed): for rule in self.rules: rule.process_statement(execute_observed) @@ -331,7 +338,8 @@ class SQLExecuteObserved(object): class SQLCursorExecuteObserved( collections.namedtuple( "SQLCursorExecuteObserved", - ["statement", "parameters", "context", "executemany"]) + ["statement", "parameters", "context", "executemany"], + ) ): pass @@ -374,21 +382,25 @@ def assert_engine(engine): orig[:] = clauseelement, multiparams, params @event.listens_for(engine, "after_cursor_execute") - def cursor_execute(conn, cursor, statement, parameters, - context, executemany): + def cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): if not context: return # then grab real cursor statements and associate them all # around a single context - if asserter.accumulated and \ - asserter.accumulated[-1].context is context: + if ( + asserter.accumulated + and asserter.accumulated[-1].context is context + ): obs = asserter.accumulated[-1] else: obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2]) asserter.accumulated.append(obs) obs.statements.append( SQLCursorExecuteObserved( - statement, parameters, context, executemany) + statement, parameters, context, executemany + ) ) try: diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index e9cfb3de9..1ff282af5 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -64,8 +64,9 @@ class Config(object): assert _current, "Can't push without a default Config set up" cls.push( Config( - db, _current.db_opts, _current.options, _current.file_config), - namespace + db, _current.db_opts, _current.options, _current.file_config + ), + namespace, ) @classmethod @@ -94,4 +95,3 @@ class Config(object): def skip_test(msg): raise _skip_test_exception(msg) - diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index d17e30edf..074e3b338 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -16,7 +16,6 @@ import warnings class ConnectionKiller(object): - def __init__(self): self.proxy_refs = weakref.WeakKeyDictionary() self.testing_engines = weakref.WeakKeyDictionary() @@ -39,8 +38,8 @@ class ConnectionKiller(object): fn() except Exception as e: warnings.warn( - "testing_reaper couldn't " - "rollback/close connection: %s" % e) + "testing_reaper couldn't " "rollback/close connection: %s" % e + ) def rollback_all(self): for rec in list(self.proxy_refs): @@ -97,18 +96,19 @@ class ConnectionKiller(object): if rec.is_valid: assert False + testing_reaper = ConnectionKiller() def drop_all_tables(metadata, bind): testing_reaper.close_all() - if hasattr(bind, 'close'): + if hasattr(bind, "close"): bind.close() if not config.db.dialect.supports_alter: from . import assertions - with assertions.expect_warnings( - "Can't sort tables", assert_=False): + + with assertions.expect_warnings("Can't sort tables", assert_=False): metadata.drop_all(bind) else: metadata.drop_all(bind) @@ -151,19 +151,20 @@ def close_open_connections(fn, *args, **kw): def all_dialects(exclude=None): import sqlalchemy.databases as d + for name in d.__all__: # TEMPORARY if exclude and name in exclude: continue mod = getattr(d, name, None) if not mod: - mod = getattr(__import__( - 'sqlalchemy.databases.%s' % name).databases, name) + mod = getattr( + __import__("sqlalchemy.databases.%s" % name).databases, name + ) yield mod.dialect() class ReconnectFixture(object): - def __init__(self, dbapi): self.dbapi = dbapi self.connections = [] @@ -191,8 +192,8 @@ class ReconnectFixture(object): fn() except Exception as e: warnings.warn( - "ReconnectFixture couldn't " - "close connection: %s" % e) + "ReconnectFixture couldn't " "close connection: %s" % e + ) def shutdown(self, stop=False): # TODO: this doesn't cover all cases @@ -214,7 +215,7 @@ def reconnecting_engine(url=None, options=None): dbapi = config.db.dialect.dbapi if not options: options = {} - options['module'] = ReconnectFixture(dbapi) + options["module"] = ReconnectFixture(dbapi) engine = testing_engine(url, options) _dispose = engine.dispose @@ -238,7 +239,7 @@ def testing_engine(url=None, options=None): if not options: use_reaper = True else: - use_reaper = options.pop('use_reaper', True) + use_reaper = options.pop("use_reaper", True) url = url or config.db.url @@ -253,15 +254,15 @@ def testing_engine(url=None, options=None): default_opt.update(options) engine = create_engine(url, **options) - engine._has_events = True # enable event blocks, helps with profiling + engine._has_events = True # enable event blocks, helps with profiling if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 engine.pool._max_overflow = 0 if use_reaper: - event.listen(engine.pool, 'connect', testing_reaper.connect) - event.listen(engine.pool, 'checkout', testing_reaper.checkout) - event.listen(engine.pool, 'invalidate', testing_reaper.invalidate) + event.listen(engine.pool, "connect", testing_reaper.connect) + event.listen(engine.pool, "checkout", testing_reaper.checkout) + event.listen(engine.pool, "invalidate", testing_reaper.invalidate) testing_reaper.add_engine(engine) return engine @@ -290,19 +291,17 @@ def mock_engine(dialect_name=None): buffer.append(sql) def assert_sql(stmts): - recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] + recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer] assert recv == stmts, recv def print_sql(): d = engine.dialect - return "\n".join( - str(s.compile(dialect=d)) - for s in engine.mock - ) - - engine = create_engine(dialect_name + '://', - strategy='mock', executor=executor) - assert not hasattr(engine, 'mock') + return "\n".join(str(s.compile(dialect=d)) for s in engine.mock) + + engine = create_engine( + dialect_name + "://", strategy="mock", executor=executor + ) + assert not hasattr(engine, "mock") engine.mock = buffer engine.assert_sql = assert_sql engine.print_sql = print_sql @@ -358,14 +357,15 @@ class DBAPIProxyConnection(object): return getattr(self.conn, key) -def proxying_engine(conn_cls=DBAPIProxyConnection, - cursor_cls=DBAPIProxyCursor): +def proxying_engine( + conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor +): """Produce an engine that provides proxy hooks for common methods. """ + def mock_conn(): return conn_cls(config.db, cursor_cls) - return testing_engine(options={'creator': mock_conn}) - + return testing_engine(options={"creator": mock_conn}) diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py index b634735fc..42c42149c 100644 --- a/lib/sqlalchemy/testing/entities.py +++ b/lib/sqlalchemy/testing/entities.py @@ -12,7 +12,6 @@ _repr_stack = set() class BasicEntity(object): - def __init__(self, **kw): for key, value in kw.items(): setattr(self, key, value) @@ -24,17 +23,22 @@ class BasicEntity(object): try: return "%s(%s)" % ( (self.__class__.__name__), - ', '.join(["%s=%r" % (key, getattr(self, key)) - for key in sorted(self.__dict__.keys()) - if not key.startswith('_')])) + ", ".join( + [ + "%s=%r" % (key, getattr(self, key)) + for key in sorted(self.__dict__.keys()) + if not key.startswith("_") + ] + ), + ) finally: _repr_stack.remove(id(self)) + _recursion_stack = set() class ComparableEntity(BasicEntity): - def __hash__(self): return hash(self.__class__) @@ -75,7 +79,7 @@ class ComparableEntity(BasicEntity): b = other for attr in list(a.__dict__): - if attr.startswith('_'): + if attr.startswith("_"): continue value = getattr(a, attr) @@ -85,9 +89,10 @@ class ComparableEntity(BasicEntity): except (AttributeError, sa_exc.UnboundExecutionError): return False - if hasattr(value, '__iter__'): - if hasattr(value, '__getitem__') and not hasattr( - value, 'keys'): + if hasattr(value, "__iter__"): + if hasattr(value, "__getitem__") and not hasattr( + value, "keys" + ): if list(value) != list(battr): return False else: diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 512fffb3b..9ed9e42c3 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -16,6 +16,7 @@ from . import config from .. import util from ..util import decorator + def skip_if(predicate, reason=None): rule = compound() pred = _as_predicate(predicate, reason) @@ -70,15 +71,15 @@ class compound(object): def matching_config_reasons(self, config): return [ - predicate._as_string(config) for predicate - in self.skips.union(self.fails) + predicate._as_string(config) + for predicate in self.skips.union(self.fails) if predicate(config) ] def include_test(self, include_tags, exclude_tags): return bool( - not self.tags.intersection(exclude_tags) and - (not include_tags or self.tags.intersection(include_tags)) + not self.tags.intersection(exclude_tags) + and (not include_tags or self.tags.intersection(include_tags)) ) def _extend(self, other): @@ -87,13 +88,14 @@ class compound(object): self.tags.update(other.tags) def __call__(self, fn): - if hasattr(fn, '_sa_exclusion_extend'): + if hasattr(fn, "_sa_exclusion_extend"): fn._sa_exclusion_extend._extend(self) return fn @decorator def decorate(fn, *args, **kw): return self._do(config._current, fn, *args, **kw) + decorated = decorate(fn) decorated._sa_exclusion_extend = self return decorated @@ -113,10 +115,7 @@ class compound(object): def _do(self, cfg, fn, *args, **kw): for skip in self.skips: if skip(cfg): - msg = "'%s' : %s" % ( - fn.__name__, - skip._as_string(cfg) - ) + msg = "'%s' : %s" % (fn.__name__, skip._as_string(cfg)) config.skip_test(msg) try: @@ -127,16 +126,20 @@ class compound(object): self._expect_success(cfg, name=fn.__name__) return return_value - def _expect_failure(self, config, ex, name='block'): + def _expect_failure(self, config, ex, name="block"): for fail in self.fails: if fail(config): - print(("%s failed as expected (%s): %s " % ( - name, fail._as_string(config), str(ex)))) + print( + ( + "%s failed as expected (%s): %s " + % (name, fail._as_string(config), str(ex)) + ) + ) break else: util.raise_from_cause(ex) - def _expect_success(self, config, name='block'): + def _expect_success(self, config, name="block"): if not self.fails: return for fail in self.fails: @@ -144,13 +147,12 @@ class compound(object): break else: raise AssertionError( - "Unexpected success for '%s' (%s)" % - ( + "Unexpected success for '%s' (%s)" + % ( name, " and ".join( - fail._as_string(config) - for fail in self.fails - ) + fail._as_string(config) for fail in self.fails + ), ) ) @@ -186,21 +188,24 @@ class Predicate(object): return predicate elif isinstance(predicate, (list, set)): return OrPredicate( - [cls.as_predicate(pred) for pred in predicate], - description) + [cls.as_predicate(pred) for pred in predicate], description + ) elif isinstance(predicate, tuple): return SpecPredicate(*predicate) elif isinstance(predicate, util.string_types): tokens = re.match( - r'([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?', predicate) + r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate + ) if not tokens: raise ValueError( - "Couldn't locate DB name in predicate: %r" % predicate) + "Couldn't locate DB name in predicate: %r" % predicate + ) db = tokens.group(1) op = tokens.group(2) spec = ( tuple(int(d) for d in tokens.group(3).split(".")) - if tokens.group(3) else None + if tokens.group(3) + else None ) return SpecPredicate(db, op, spec, description=description) @@ -215,11 +220,13 @@ class Predicate(object): bool_ = not negate return self.description % { "driver": config.db.url.get_driver_name() - if config else "<no driver>", + if config + else "<no driver>", "database": config.db.url.get_backend_name() - if config else "<no database>", + if config + else "<no database>", "doesnt_support": "doesn't support" if bool_ else "does support", - "does_support": "does support" if bool_ else "doesn't support" + "does_support": "does support" if bool_ else "doesn't support", } def _as_string(self, config=None, negate=False): @@ -246,21 +253,21 @@ class SpecPredicate(Predicate): self.description = description _ops = { - '<': operator.lt, - '>': operator.gt, - '==': operator.eq, - '!=': operator.ne, - '<=': operator.le, - '>=': operator.ge, - 'in': operator.contains, - 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + "<": operator.lt, + ">": operator.gt, + "==": operator.eq, + "!=": operator.ne, + "<=": operator.le, + ">=": operator.ge, + "in": operator.contains, + "between": lambda val, pair: val >= pair[0] and val <= pair[1], } def __call__(self, config): engine = config.db if "+" in self.db: - dialect, driver = self.db.split('+') + dialect, driver = self.db.split("+") else: dialect, driver = self.db, None @@ -273,8 +280,9 @@ class SpecPredicate(Predicate): assert driver is None, "DBAPI version specs not supported yet" version = _server_version(engine) - oper = hasattr(self.op, '__call__') and self.op \ - or self._ops[self.op] + oper = ( + hasattr(self.op, "__call__") and self.op or self._ops[self.op] + ) return oper(version, self.spec) else: return True @@ -289,17 +297,9 @@ class SpecPredicate(Predicate): return "%s" % self.db else: if negate: - return "not %s %s %s" % ( - self.db, - self.op, - self.spec - ) + return "not %s %s %s" % (self.db, self.op, self.spec) else: - return "%s %s %s" % ( - self.db, - self.op, - self.spec - ) + return "%s %s %s" % (self.db, self.op, self.spec) class LambdaPredicate(Predicate): @@ -356,8 +356,9 @@ class OrPredicate(Predicate): conjunction = " and " else: conjunction = " or " - return conjunction.join(p._as_string(config, negate=negate) - for p in self.predicates) + return conjunction.join( + p._as_string(config, negate=negate) for p in self.predicates + ) def _negation_str(self, config): if self.description is not None: @@ -387,7 +388,7 @@ def _server_version(engine): # force metadata to be retrieved conn = engine.connect() - version = getattr(engine.dialect, 'server_version_info', None) + version = getattr(engine.dialect, "server_version_info", None) if version is None: version = () conn.close() @@ -395,9 +396,7 @@ def _server_version(engine): def db_spec(*dbs): - return OrPredicate( - [Predicate.as_predicate(db) for db in dbs] - ) + return OrPredicate([Predicate.as_predicate(db) for db in dbs]) def open(): @@ -422,11 +421,7 @@ def fails_on(db, reason=None): def fails_on_everything_except(*dbs): - return succeeds_if( - OrPredicate([ - Predicate.as_predicate(db) for db in dbs - ]) - ) + return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs])) def skip(db, reason=None): @@ -435,8 +430,9 @@ def skip(db, reason=None): def only_on(dbs, reason=None): return only_if( - OrPredicate([Predicate.as_predicate(db, reason) - for db in util.to_list(dbs)]) + OrPredicate( + [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)] + ) ) @@ -446,7 +442,6 @@ def exclude(db, op, spec, reason=None): def against(config, *queries): assert queries, "no queries sent!" - return OrPredicate([ - Predicate.as_predicate(query) - for query in queries - ])(config) + return OrPredicate([Predicate.as_predicate(query) for query in queries])( + config + ) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index dd0fa5a48..98184cdd4 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -54,19 +54,19 @@ class TestBase(object): class TablesTest(TestBase): # 'once', None - run_setup_bind = 'once' + run_setup_bind = "once" # 'once', 'each', None - run_define_tables = 'once' + run_define_tables = "once" # 'once', 'each', None - run_create_tables = 'once' + run_create_tables = "once" # 'once', 'each', None - run_inserts = 'each' + run_inserts = "each" # 'each', None - run_deletes = 'each' + run_deletes = "each" # 'once', None run_dispose_bind = None @@ -86,10 +86,10 @@ class TablesTest(TestBase): @classmethod def _init_class(cls): - if cls.run_define_tables == 'each': - if cls.run_create_tables == 'once': - cls.run_create_tables = 'each' - assert cls.run_inserts in ('each', None) + if cls.run_define_tables == "each": + if cls.run_create_tables == "once": + cls.run_create_tables = "each" + assert cls.run_inserts in ("each", None) cls.other = adict() cls.tables = adict() @@ -100,40 +100,40 @@ class TablesTest(TestBase): @classmethod def _setup_once_inserts(cls): - if cls.run_inserts == 'once': + if cls.run_inserts == "once": cls._load_fixtures() cls.insert_data() @classmethod def _setup_once_tables(cls): - if cls.run_define_tables == 'once': + if cls.run_define_tables == "once": cls.define_tables(cls.metadata) - if cls.run_create_tables == 'once': + if cls.run_create_tables == "once": cls.metadata.create_all(cls.bind) cls.tables.update(cls.metadata.tables) def _setup_each_tables(self): - if self.run_define_tables == 'each': + if self.run_define_tables == "each": self.tables.clear() - if self.run_create_tables == 'each': + if self.run_create_tables == "each": drop_all_tables(self.metadata, self.bind) self.metadata.clear() self.define_tables(self.metadata) - if self.run_create_tables == 'each': + if self.run_create_tables == "each": self.metadata.create_all(self.bind) self.tables.update(self.metadata.tables) - elif self.run_create_tables == 'each': + elif self.run_create_tables == "each": drop_all_tables(self.metadata, self.bind) self.metadata.create_all(self.bind) def _setup_each_inserts(self): - if self.run_inserts == 'each': + if self.run_inserts == "each": self._load_fixtures() self.insert_data() def _teardown_each_tables(self): # no need to run deletes if tables are recreated on setup - if self.run_define_tables != 'each' and self.run_deletes == 'each': + if self.run_define_tables != "each" and self.run_deletes == "each": with self.bind.connect() as conn: for table in reversed(self.metadata.sorted_tables): try: @@ -141,7 +141,8 @@ class TablesTest(TestBase): except sa.exc.DBAPIError as ex: util.print_( ("Error emptying table %s: %r" % (table, ex)), - file=sys.stderr) + file=sys.stderr, + ) def setup(self): self._setup_each_tables() @@ -155,7 +156,7 @@ class TablesTest(TestBase): if cls.run_create_tables: drop_all_tables(cls.metadata, cls.bind) - if cls.run_dispose_bind == 'once': + if cls.run_dispose_bind == "once": cls.dispose_bind(cls.bind) cls.metadata.bind = None @@ -173,9 +174,9 @@ class TablesTest(TestBase): @classmethod def dispose_bind(cls, bind): - if hasattr(bind, 'dispose'): + if hasattr(bind, "dispose"): bind.dispose() - elif hasattr(bind, 'close'): + elif hasattr(bind, "close"): bind.close() @classmethod @@ -212,8 +213,12 @@ class TablesTest(TestBase): continue cls.bind.execute( table.insert(), - [dict(zip(headers[table], column_values)) - for column_values in rows[table]]) + [ + dict(zip(headers[table], column_values)) + for column_values in rows[table] + ], + ) + from sqlalchemy import event @@ -236,7 +241,6 @@ class RemovesEvents(object): class _ORMTest(object): - @classmethod def teardown_class(cls): sa.orm.session.Session.close_all() @@ -249,10 +253,10 @@ class ORMTest(_ORMTest, TestBase): class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): # 'once', 'each', None - run_setup_classes = 'once' + run_setup_classes = "once" # 'once', 'each', None - run_setup_mappers = 'each' + run_setup_mappers = "each" classes = None @@ -292,20 +296,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): @classmethod def _setup_once_classes(cls): - if cls.run_setup_classes == 'once': + if cls.run_setup_classes == "once": cls._with_register_classes(cls.setup_classes) @classmethod def _setup_once_mappers(cls): - if cls.run_setup_mappers == 'once': + if cls.run_setup_mappers == "once": cls._with_register_classes(cls.setup_mappers) def _setup_each_mappers(self): - if self.run_setup_mappers == 'each': + if self.run_setup_mappers == "each": self._with_register_classes(self.setup_mappers) def _setup_each_classes(self): - if self.run_setup_classes == 'each': + if self.run_setup_classes == "each": self._with_register_classes(self.setup_classes) @classmethod @@ -339,11 +343,11 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): # some tests create mappers in the test bodies # and will define setup_mappers as None - # clear mappers in any case - if self.run_setup_mappers != 'once': + if self.run_setup_mappers != "once": sa.orm.clear_mappers() def _teardown_each_classes(self): - if self.run_setup_classes != 'once': + if self.run_setup_classes != "once": self.classes.clear() @classmethod @@ -356,8 +360,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): class DeclarativeMappedTest(MappedTest): - run_setup_classes = 'once' - run_setup_mappers = 'once' + run_setup_classes = "once" + run_setup_mappers = "once" @classmethod def _setup_once_tables(cls): @@ -370,15 +374,16 @@ class DeclarativeMappedTest(MappedTest): class FindFixtureDeclarative(DeclarativeMeta): def __init__(cls, classname, bases, dict_): cls_registry[classname] = cls - return DeclarativeMeta.__init__( - cls, classname, bases, dict_) + return DeclarativeMeta.__init__(cls, classname, bases, dict_) class DeclarativeBasic(object): __table_cls__ = schema.Table - _DeclBase = declarative_base(metadata=cls.metadata, - metaclass=FindFixtureDeclarative, - cls=DeclarativeBasic) + _DeclBase = declarative_base( + metadata=cls.metadata, + metaclass=FindFixtureDeclarative, + cls=DeclarativeBasic, + ) cls.DeclarativeBasic = _DeclBase fn() diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py index ea0a8da82..dc530af5e 100644 --- a/lib/sqlalchemy/testing/mock.py +++ b/lib/sqlalchemy/testing/mock.py @@ -18,4 +18,5 @@ else: except ImportError: raise ImportError( "SQLAlchemy's test suite requires the " - "'mock' library as of 0.8.2.") + "'mock' library as of 0.8.2." + ) diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py index 087fc1fe6..e84cbde44 100644 --- a/lib/sqlalchemy/testing/pickleable.py +++ b/lib/sqlalchemy/testing/pickleable.py @@ -46,29 +46,28 @@ class Parent(fixtures.ComparableEntity): class Screen(object): - def __init__(self, obj, parent=None): self.obj = obj self.parent = parent class Foo(object): - def __init__(self, moredata): - self.data = 'im data' - self.stuff = 'im stuff' + self.data = "im data" + self.stuff = "im stuff" self.moredata = moredata __hash__ = object.__hash__ def __eq__(self, other): - return other.data == self.data and \ - other.stuff == self.stuff and \ - other.moredata == self.moredata + return ( + other.data == self.data + and other.stuff == self.stuff + and other.moredata == self.moredata + ) class Bar(object): - def __init__(self, x, y): self.x = x self.y = y @@ -76,35 +75,36 @@ class Bar(object): __hash__ = object.__hash__ def __eq__(self, other): - return other.__class__ is self.__class__ and \ - other.x == self.x and \ - other.y == self.y + return ( + other.__class__ is self.__class__ + and other.x == self.x + and other.y == self.y + ) def __str__(self): return "Bar(%d, %d)" % (self.x, self.y) class OldSchool: - def __init__(self, x, y): self.x = x self.y = y def __eq__(self, other): - return other.__class__ is self.__class__ and \ - other.x == self.x and \ - other.y == self.y + return ( + other.__class__ is self.__class__ + and other.x == self.x + and other.y == self.y + ) class OldSchoolWithoutCompare: - def __init__(self, x, y): self.x = x self.y = y class BarWithoutCompare(object): - def __init__(self, x, y): self.x = x self.y = y @@ -114,7 +114,6 @@ class BarWithoutCompare(object): class NotComparable(object): - def __init__(self, data): self.data = data @@ -129,7 +128,6 @@ class NotComparable(object): class BrokenComparable(object): - def __init__(self, data): self.data = data diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py index 497fcb7e5..bb52c125c 100644 --- a/lib/sqlalchemy/testing/plugin/bootstrap.py +++ b/lib/sqlalchemy/testing/plugin/bootstrap.py @@ -20,20 +20,23 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0. import os import sys -bootstrap_file = locals()['bootstrap_file'] -to_bootstrap = locals()['to_bootstrap'] +bootstrap_file = locals()["bootstrap_file"] +to_bootstrap = locals()["to_bootstrap"] def load_file_as_module(name): path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name) if sys.version_info >= (3, 3): from importlib import machinery + mod = machinery.SourceFileLoader(name, path).load_module() else: import imp + mod = imp.load_source(name, path) return mod + if to_bootstrap == "pytest": sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin") diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py index 20ea61d89..0c28a5213 100644 --- a/lib/sqlalchemy/testing/plugin/noseplugin.py +++ b/lib/sqlalchemy/testing/plugin/noseplugin.py @@ -25,6 +25,7 @@ import sys from nose.plugins import Plugin import nose + fixtures = None py3k = sys.version_info >= (3, 0) @@ -33,7 +34,7 @@ py3k = sys.version_info >= (3, 0) class NoseSQLAlchemy(Plugin): enabled = True - name = 'sqla_testing' + name = "sqla_testing" score = 100 def options(self, parser, env=os.environ): @@ -41,10 +42,14 @@ class NoseSQLAlchemy(Plugin): opt = parser.add_option def make_option(name, **kw): - callback_ = kw.pop("callback", None) or kw.pop("zeroarg_callback", None) + callback_ = kw.pop("callback", None) or kw.pop( + "zeroarg_callback", None + ) if callback_: + def wrap_(option, opt_str, value, parser): callback_(opt_str, value, parser) + kw["callback"] = wrap_ opt(name, **kw) @@ -73,7 +78,7 @@ class NoseSQLAlchemy(Plugin): def wantMethod(self, fn): if py3k: - if not hasattr(fn.__self__, 'cls'): + if not hasattr(fn.__self__, "cls"): return False cls = fn.__self__.cls else: @@ -84,24 +89,24 @@ class NoseSQLAlchemy(Plugin): return plugin_base.want_class(cls) def beforeTest(self, test): - if not hasattr(test.test, 'cls'): + if not hasattr(test.test, "cls"): return plugin_base.before_test( test, test.test.cls.__module__, - test.test.cls, test.test.method.__name__) + test.test.cls, + test.test.method.__name__, + ) def afterTest(self, test): plugin_base.after_test(test) def startContext(self, ctx): - if not isinstance(ctx, type) \ - or not issubclass(ctx, fixtures.TestBase): + if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase): return plugin_base.start_test_class(ctx) def stopContext(self, ctx): - if not isinstance(ctx, type) \ - or not issubclass(ctx, fixtures.TestBase): + if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase): return plugin_base.stop_test_class(ctx) diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 0ffcae093..5d6bf2975 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -46,58 +46,130 @@ options = None def setup_options(make_option): - make_option("--log-info", action="callback", type="string", callback=_log, - help="turn on info logging for <LOG> (multiple OK)") - make_option("--log-debug", action="callback", - type="string", callback=_log, - help="turn on debug logging for <LOG> (multiple OK)") - make_option("--db", action="append", type="string", dest="db", - help="Use prefab database uri. Multiple OK, " - "first one is run by default.") - make_option('--dbs', action='callback', zeroarg_callback=_list_dbs, - help="List available prefab dbs") - make_option("--dburi", action="append", type="string", dest="dburi", - help="Database uri. Multiple OK, " - "first one is run by default.") - make_option("--dropfirst", action="store_true", dest="dropfirst", - help="Drop all tables in the target database first") - make_option("--backend-only", action="store_true", dest="backend_only", - help="Run only tests marked with __backend__") - make_option("--nomemory", action="store_true", dest="nomemory", - help="Don't run memory profiling tests") - make_option("--postgresql-templatedb", type="string", - help="name of template database to use for PostgreSQL " - "CREATE DATABASE (defaults to current database)") - make_option("--low-connections", action="store_true", - dest="low_connections", - help="Use a low number of distinct connections - " - "i.e. for Oracle TNS") - make_option("--write-idents", type="string", dest="write_idents", - help="write out generated follower idents to <file>, " - "when -n<num> is used") - make_option("--reversetop", action="store_true", - dest="reversetop", default=False, - help="Use a random-ordering set implementation in the ORM " - "(helps reveal dependency issues)") - make_option("--requirements", action="callback", type="string", - callback=_requirements_opt, - help="requirements class for testing, overrides setup.cfg") - make_option("--with-cdecimal", action="store_true", - dest="cdecimal", default=False, - help="Monkeypatch the cdecimal library into Python 'decimal' " - "for all tests") - make_option("--include-tag", action="callback", callback=_include_tag, - type="string", - help="Include tests with tag <tag>") - make_option("--exclude-tag", action="callback", callback=_exclude_tag, - type="string", - help="Exclude tests with tag <tag>") - make_option("--write-profiles", action="store_true", - dest="write_profiles", default=False, - help="Write/update failing profiling data.") - make_option("--force-write-profiles", action="store_true", - dest="force_write_profiles", default=False, - help="Unconditionally write/update profiling data.") + make_option( + "--log-info", + action="callback", + type="string", + callback=_log, + help="turn on info logging for <LOG> (multiple OK)", + ) + make_option( + "--log-debug", + action="callback", + type="string", + callback=_log, + help="turn on debug logging for <LOG> (multiple OK)", + ) + make_option( + "--db", + action="append", + type="string", + dest="db", + help="Use prefab database uri. Multiple OK, " + "first one is run by default.", + ) + make_option( + "--dbs", + action="callback", + zeroarg_callback=_list_dbs, + help="List available prefab dbs", + ) + make_option( + "--dburi", + action="append", + type="string", + dest="dburi", + help="Database uri. Multiple OK, " "first one is run by default.", + ) + make_option( + "--dropfirst", + action="store_true", + dest="dropfirst", + help="Drop all tables in the target database first", + ) + make_option( + "--backend-only", + action="store_true", + dest="backend_only", + help="Run only tests marked with __backend__", + ) + make_option( + "--nomemory", + action="store_true", + dest="nomemory", + help="Don't run memory profiling tests", + ) + make_option( + "--postgresql-templatedb", + type="string", + help="name of template database to use for PostgreSQL " + "CREATE DATABASE (defaults to current database)", + ) + make_option( + "--low-connections", + action="store_true", + dest="low_connections", + help="Use a low number of distinct connections - " + "i.e. for Oracle TNS", + ) + make_option( + "--write-idents", + type="string", + dest="write_idents", + help="write out generated follower idents to <file>, " + "when -n<num> is used", + ) + make_option( + "--reversetop", + action="store_true", + dest="reversetop", + default=False, + help="Use a random-ordering set implementation in the ORM " + "(helps reveal dependency issues)", + ) + make_option( + "--requirements", + action="callback", + type="string", + callback=_requirements_opt, + help="requirements class for testing, overrides setup.cfg", + ) + make_option( + "--with-cdecimal", + action="store_true", + dest="cdecimal", + default=False, + help="Monkeypatch the cdecimal library into Python 'decimal' " + "for all tests", + ) + make_option( + "--include-tag", + action="callback", + callback=_include_tag, + type="string", + help="Include tests with tag <tag>", + ) + make_option( + "--exclude-tag", + action="callback", + callback=_exclude_tag, + type="string", + help="Exclude tests with tag <tag>", + ) + make_option( + "--write-profiles", + action="store_true", + dest="write_profiles", + default=False, + help="Write/update failing profiling data.", + ) + make_option( + "--force-write-profiles", + action="store_true", + dest="force_write_profiles", + default=False, + help="Unconditionally write/update profiling data.", + ) def configure_follower(follower_ident): @@ -108,6 +180,7 @@ def configure_follower(follower_ident): """ from sqlalchemy.testing import provision + provision.FOLLOWER_IDENT = follower_ident @@ -121,9 +194,9 @@ def memoize_important_follower_config(dict_): callables, so we have to just copy all of that over. """ - dict_['memoized_config'] = { - 'include_tags': include_tags, - 'exclude_tags': exclude_tags + dict_["memoized_config"] = { + "include_tags": include_tags, + "exclude_tags": exclude_tags, } @@ -134,14 +207,14 @@ def restore_important_follower_config(dict_): """ global include_tags, exclude_tags - include_tags.update(dict_['memoized_config']['include_tags']) - exclude_tags.update(dict_['memoized_config']['exclude_tags']) + include_tags.update(dict_["memoized_config"]["include_tags"]) + exclude_tags.update(dict_["memoized_config"]["exclude_tags"]) def read_config(): global file_config file_config = configparser.ConfigParser() - file_config.read(['setup.cfg', 'test.cfg']) + file_config.read(["setup.cfg", "test.cfg"]) def pre_begin(opt): @@ -155,6 +228,7 @@ def pre_begin(opt): def set_coverage_flag(value): options.has_coverage = value + _skip_test_exception = None @@ -171,34 +245,33 @@ def post_begin(): # late imports, has to happen after config as well # as nose plugins like coverage - global util, fixtures, engines, exclusions, \ - assertions, warnings, profiling,\ - config, testing - from sqlalchemy import testing # noqa + global util, fixtures, engines, exclusions, assertions, warnings, profiling, config, testing + from sqlalchemy import testing # noqa from sqlalchemy.testing import fixtures, engines, exclusions # noqa - from sqlalchemy.testing import assertions, warnings, profiling # noqa + from sqlalchemy.testing import assertions, warnings, profiling # noqa from sqlalchemy.testing import config # noqa from sqlalchemy import util # noqa - warnings.setup_filters() + warnings.setup_filters() def _log(opt_str, value, parser): global logging if not logging: import logging + logging.basicConfig() - if opt_str.endswith('-info'): + if opt_str.endswith("-info"): logging.getLogger(value).setLevel(logging.INFO) - elif opt_str.endswith('-debug'): + elif opt_str.endswith("-debug"): logging.getLogger(value).setLevel(logging.DEBUG) def _list_dbs(*args): print("Available --db options (use --dburi to override)") - for macro in sorted(file_config.options('db')): - print("%20s\t%s" % (macro, file_config.get('db', macro))) + for macro in sorted(file_config.options("db")): + print("%20s\t%s" % (macro, file_config.get("db", macro))) sys.exit(0) @@ -207,11 +280,12 @@ def _requirements_opt(opt_str, value, parser): def _exclude_tag(opt_str, value, parser): - exclude_tags.add(value.replace('-', '_')) + exclude_tags.add(value.replace("-", "_")) def _include_tag(opt_str, value, parser): - include_tags.add(value.replace('-', '_')) + include_tags.add(value.replace("-", "_")) + pre_configure = [] post_configure = [] @@ -243,7 +317,8 @@ def _set_nomemory(opt, file_config): def _monkeypatch_cdecimal(options, file_config): if options.cdecimal: import cdecimal - sys.modules['decimal'] = cdecimal + + sys.modules["decimal"] = cdecimal @post @@ -266,27 +341,28 @@ def _engine_uri(options, file_config): if options.db: for db_token in options.db: - for db in re.split(r'[,\s]+', db_token): - if db not in file_config.options('db'): + for db in re.split(r"[,\s]+", db_token): + if db not in file_config.options("db"): raise RuntimeError( "Unknown URI specifier '%s'. " - "Specify --dbs for known uris." - % db) + "Specify --dbs for known uris." % db + ) else: - db_urls.append(file_config.get('db', db)) + db_urls.append(file_config.get("db", db)) if not db_urls: - db_urls.append(file_config.get('db', 'default')) + db_urls.append(file_config.get("db", "default")) config._current = None for db_url in db_urls: - if options.write_idents and provision.FOLLOWER_IDENT: # != 'master': + if options.write_idents and provision.FOLLOWER_IDENT: # != 'master': with open(options.write_idents, "a") as file_: file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n") cfg = provision.setup_config( - db_url, options, file_config, provision.FOLLOWER_IDENT) + db_url, options, file_config, provision.FOLLOWER_IDENT + ) if not config._current: cfg.set_as_current(cfg, testing) @@ -295,7 +371,7 @@ def _engine_uri(options, file_config): @post def _requirements(options, file_config): - requirement_cls = file_config.get('sqla_testing', "requirement_cls") + requirement_cls = file_config.get("sqla_testing", "requirement_cls") _setup_requirements(requirement_cls) @@ -334,22 +410,28 @@ def _prep_testing_database(options, file_config): pass else: for vname in view_names: - e.execute(schema._DropView( - schema.Table(vname, schema.MetaData()) - )) + e.execute( + schema._DropView( + schema.Table(vname, schema.MetaData()) + ) + ) if config.requirements.schemas.enabled_for_config(cfg): try: - view_names = inspector.get_view_names( - schema="test_schema") + view_names = inspector.get_view_names(schema="test_schema") except NotImplementedError: pass else: for vname in view_names: - e.execute(schema._DropView( - schema.Table(vname, schema.MetaData(), - schema="test_schema") - )) + e.execute( + schema._DropView( + schema.Table( + vname, + schema.MetaData(), + schema="test_schema", + ) + ) + ) util.drop_all_tables(e, inspector) @@ -358,23 +440,29 @@ def _prep_testing_database(options, file_config): if against(cfg, "postgresql"): from sqlalchemy.dialects import postgresql + for enum in inspector.get_enums("*"): - e.execute(postgresql.DropEnumType( - postgresql.ENUM( - name=enum['name'], - schema=enum['schema']))) + e.execute( + postgresql.DropEnumType( + postgresql.ENUM( + name=enum["name"], schema=enum["schema"] + ) + ) + ) @post def _reverse_topological(options, file_config): if options.reversetop: from sqlalchemy.orm.util import randomize_unitofwork + randomize_unitofwork() @post def _post_setup_options(opt, file_config): from sqlalchemy.testing import config + config.options = options config.file_config = file_config @@ -382,17 +470,20 @@ def _post_setup_options(opt, file_config): @post def _setup_profiling(options, file_config): from sqlalchemy.testing import profiling + profiling._profile_stats = profiling.ProfileStatsFile( - file_config.get('sqla_testing', 'profile_file')) + file_config.get("sqla_testing", "profile_file") + ) def want_class(cls): if not issubclass(cls, fixtures.TestBase): return False - elif cls.__name__.startswith('_'): + elif cls.__name__.startswith("_"): return False - elif config.options.backend_only and not getattr(cls, '__backend__', - False): + elif config.options.backend_only and not getattr( + cls, "__backend__", False + ): return False else: return True @@ -405,25 +496,28 @@ def want_method(cls, fn): return False elif include_tags: return ( - hasattr(cls, '__tags__') and - exclusions.tags(cls.__tags__).include_test( - include_tags, exclude_tags) + hasattr(cls, "__tags__") + and exclusions.tags(cls.__tags__).include_test( + include_tags, exclude_tags + ) ) or ( - hasattr(fn, '_sa_exclusion_extend') and - fn._sa_exclusion_extend.include_test( - include_tags, exclude_tags) + hasattr(fn, "_sa_exclusion_extend") + and fn._sa_exclusion_extend.include_test( + include_tags, exclude_tags + ) ) - elif exclude_tags and hasattr(cls, '__tags__'): + elif exclude_tags and hasattr(cls, "__tags__"): return exclusions.tags(cls.__tags__).include_test( - include_tags, exclude_tags) - elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'): + include_tags, exclude_tags + ) + elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"): return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags) else: return True def generate_sub_tests(cls, module): - if getattr(cls, '__backend__', False): + if getattr(cls, "__backend__", False): for cfg in _possible_configs_for_cls(cls): orig_name = cls.__name__ @@ -431,16 +525,13 @@ def generate_sub_tests(cls, module): # pytest junit plugin, which is tripped up by the brackets # and periods, so sanitize - alpha_name = re.sub('[_\[\]\.]+', '_', cfg.name) - alpha_name = re.sub('_+$', '', alpha_name) + alpha_name = re.sub("[_\[\]\.]+", "_", cfg.name) + alpha_name = re.sub("_+$", "", alpha_name) name = "%s_%s" % (cls.__name__, alpha_name) subcls = type( name, - (cls, ), - { - "_sa_orig_cls_name": orig_name, - "__only_on_config__": cfg - } + (cls,), + {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg}, ) setattr(module, name, subcls) yield subcls @@ -454,8 +545,8 @@ def start_test_class(cls): def stop_test_class(cls): - #from sqlalchemy import inspect - #assert not inspect(testing.db).get_table_names() + # from sqlalchemy import inspect + # assert not inspect(testing.db).get_table_names() engines.testing_reaper._stop_test_ctx() try: if not options.low_connections: @@ -475,7 +566,7 @@ def final_process_cleanup(): def _setup_engine(cls): - if getattr(cls, '__engine_options__', None): + if getattr(cls, "__engine_options__", None): eng = engines.testing_engine(options=cls.__engine_options__) config._current.push_engine(eng, testing) @@ -485,7 +576,7 @@ def before_test(test, test_module_name, test_class, test_name): # like a nose id, e.g.: # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause" - name = getattr(test_class, '_sa_orig_cls_name', test_class.__name__) + name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__) id_ = "%s.%s.%s" % (test_module_name, name, test_name) @@ -505,16 +596,16 @@ def _possible_configs_for_cls(cls, reasons=None): if spec(config_obj): all_configs.remove(config_obj) - if getattr(cls, '__only_on__', None): + if getattr(cls, "__only_on__", None): spec = exclusions.db_spec(*util.to_list(cls.__only_on__)) for config_obj in list(all_configs): if not spec(config_obj): all_configs.remove(config_obj) - if getattr(cls, '__only_on_config__', None): + if getattr(cls, "__only_on_config__", None): all_configs.intersection_update([cls.__only_on_config__]) - if hasattr(cls, '__requires__'): + if hasattr(cls, "__requires__"): requirements = config.requirements for config_obj in list(all_configs): for requirement in cls.__requires__: @@ -527,7 +618,7 @@ def _possible_configs_for_cls(cls, reasons=None): reasons.extend(skip_reasons) break - if hasattr(cls, '__prefer_requires__'): + if hasattr(cls, "__prefer_requires__"): non_preferred = set() requirements = config.requirements for config_obj in list(all_configs): @@ -546,30 +637,32 @@ def _do_skips(cls): reasons = [] all_configs = _possible_configs_for_cls(cls, reasons) - if getattr(cls, '__skip_if__', False): - for c in getattr(cls, '__skip_if__'): + if getattr(cls, "__skip_if__", False): + for c in getattr(cls, "__skip_if__"): if c(): - config.skip_test("'%s' skipped by %s" % ( - cls.__name__, c.__name__) + config.skip_test( + "'%s' skipped by %s" % (cls.__name__, c.__name__) ) if not all_configs: msg = "'%s' unsupported on any DB implementation %s%s" % ( cls.__name__, ", ".join( - "'%s(%s)+%s'" % ( + "'%s(%s)+%s'" + % ( config_obj.db.name, ".".join( - str(dig) for dig in - exclusions._server_version(config_obj.db)), - config_obj.db.driver + str(dig) + for dig in exclusions._server_version(config_obj.db) + ), + config_obj.db.driver, ) - for config_obj in config.Config.all_configs() + for config_obj in config.Config.all_configs() ), - ", ".join(reasons) + ", ".join(reasons), ) config.skip_test(msg) - elif hasattr(cls, '__prefer_backends__'): + elif hasattr(cls, "__prefer_backends__"): non_preferred = set() spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__)) for config_obj in all_configs: diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index da682ea00..fd0a48462 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -13,6 +13,7 @@ import os try: import xdist # noqa + has_xdist = True except ImportError: has_xdist = False @@ -24,30 +25,42 @@ def pytest_addoption(parser): def make_option(name, **kw): callback_ = kw.pop("callback", None) if callback_: + class CallableAction(argparse.Action): - def __call__(self, parser, namespace, - values, option_string=None): + def __call__( + self, parser, namespace, values, option_string=None + ): callback_(option_string, values, parser) + kw["action"] = CallableAction zeroarg_callback = kw.pop("zeroarg_callback", None) if zeroarg_callback: + class CallableAction(argparse.Action): - def __init__(self, option_strings, - dest, default=False, - required=False, help=None): - super(CallableAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=0, - const=True, - default=default, - required=required, - help=help) - - def __call__(self, parser, namespace, - values, option_string=None): + def __init__( + self, + option_strings, + dest, + default=False, + required=False, + help=None, + ): + super(CallableAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=True, + default=default, + required=required, + help=help, + ) + + def __call__( + self, parser, namespace, values, option_string=None + ): zeroarg_callback(option_string, values, parser) + kw["action"] = CallableAction group.addoption(name, **kw) @@ -59,18 +72,18 @@ def pytest_addoption(parser): def pytest_configure(config): if hasattr(config, "slaveinput"): plugin_base.restore_important_follower_config(config.slaveinput) - plugin_base.configure_follower( - config.slaveinput["follower_ident"] - ) + plugin_base.configure_follower(config.slaveinput["follower_ident"]) else: - if config.option.write_idents and \ - os.path.exists(config.option.write_idents): + if config.option.write_idents and os.path.exists( + config.option.write_idents + ): os.remove(config.option.write_idents) plugin_base.pre_begin(config.option) - plugin_base.set_coverage_flag(bool(getattr(config.option, - "cov_source", False))) + plugin_base.set_coverage_flag( + bool(getattr(config.option, "cov_source", False)) + ) plugin_base.set_skip_test(pytest.skip.Exception) @@ -94,10 +107,12 @@ if has_xdist: node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12] from sqlalchemy.testing import provision + provision.create_follower_db(node.slaveinput["follower_ident"]) def pytest_testnodedown(node, error): from sqlalchemy.testing import provision + provision.drop_follower_db(node.slaveinput["follower_ident"]) @@ -114,19 +129,22 @@ def pytest_collection_modifyitems(session, config, items): rebuilt_items = collections.defaultdict(list) items[:] = [ - item for item in - items if isinstance(item.parent, pytest.Instance) - and not item.parent.parent.name.startswith("_")] + item + for item in items + if isinstance(item.parent, pytest.Instance) + and not item.parent.parent.name.startswith("_") + ] test_classes = set(item.parent for item in items) for test_class in test_classes: for sub_cls in plugin_base.generate_sub_tests( - test_class.cls, test_class.parent.module): + test_class.cls, test_class.parent.module + ): if sub_cls is not test_class.cls: list_ = rebuilt_items[test_class.cls] for inst in pytest.Class( - sub_cls.__name__, - parent=test_class.parent.parent).collect(): + sub_cls.__name__, parent=test_class.parent.parent + ).collect(): list_.extend(inst.collect()) newitems = [] @@ -139,23 +157,29 @@ def pytest_collection_modifyitems(session, config, items): # seems like the functions attached to a test class aren't sorted already? # is that true and why's that? (when using unittest, they're sorted) - items[:] = sorted(newitems, key=lambda item: ( - item.parent.parent.parent.name, - item.parent.parent.name, - item.name - )) + items[:] = sorted( + newitems, + key=lambda item: ( + item.parent.parent.parent.name, + item.parent.parent.name, + item.name, + ), + ) def pytest_pycollect_makeitem(collector, name, obj): if inspect.isclass(obj) and plugin_base.want_class(obj): return pytest.Class(name, parent=collector) - elif inspect.isfunction(obj) and \ - isinstance(collector, pytest.Instance) and \ - plugin_base.want_method(collector.cls, obj): + elif ( + inspect.isfunction(obj) + and isinstance(collector, pytest.Instance) + and plugin_base.want_method(collector.cls, obj) + ): return pytest.Function(name, parent=collector) else: return [] + _current_class = None @@ -180,6 +204,7 @@ def pytest_runtest_setup(item): global _current_class class_teardown(item.parent.parent) _current_class = None + item.parent.parent.addfinalizer(finalize) test_setup(item) @@ -194,8 +219,9 @@ def pytest_runtest_teardown(item): def test_setup(item): - plugin_base.before_test(item, item.parent.module.__name__, - item.parent.cls, item.name) + plugin_base.before_test( + item, item.parent.module.__name__, item.parent.cls, item.name + ) def test_teardown(item): diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index fab99b186..3986985c7 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -42,17 +42,16 @@ class ProfileStatsFile(object): def __init__(self, filename): self.force_write = ( - config.options is not None and - config.options.force_write_profiles + config.options is not None and config.options.force_write_profiles ) self.write = self.force_write or ( - config.options is not None and - config.options.write_profiles + config.options is not None and config.options.write_profiles ) self.fname = os.path.abspath(filename) self.short_fname = os.path.split(self.fname)[-1] self.data = collections.defaultdict( - lambda: collections.defaultdict(dict)) + lambda: collections.defaultdict(dict) + ) self._read() if self.write: # rewrite for the case where features changed, @@ -65,7 +64,7 @@ class ProfileStatsFile(object): dbapi_key = config.db.name + "_" + config.db.driver # keep it at 2.7, 3.1, 3.2, etc. for now. - py_version = '.'.join([str(v) for v in sys.version_info[0:2]]) + py_version = ".".join([str(v) for v in sys.version_info[0:2]]) platform_tokens = [py_version] platform_tokens.append(dbapi_key) @@ -87,8 +86,7 @@ class ProfileStatsFile(object): def has_stats(self): test_key = _current_test return ( - test_key in self.data and - self.platform_key in self.data[test_key] + test_key in self.data and self.platform_key in self.data[test_key] ) def result(self, callcount): @@ -96,15 +94,15 @@ class ProfileStatsFile(object): per_fn = self.data[test_key] per_platform = per_fn[self.platform_key] - if 'counts' not in per_platform: - per_platform['counts'] = counts = [] + if "counts" not in per_platform: + per_platform["counts"] = counts = [] else: - counts = per_platform['counts'] + counts = per_platform["counts"] - if 'current_count' not in per_platform: - per_platform['current_count'] = current_count = 0 + if "current_count" not in per_platform: + per_platform["current_count"] = current_count = 0 else: - current_count = per_platform['current_count'] + current_count = per_platform["current_count"] has_count = len(counts) > current_count @@ -114,16 +112,16 @@ class ProfileStatsFile(object): self._write() result = None else: - result = per_platform['lineno'], counts[current_count] - per_platform['current_count'] += 1 + result = per_platform["lineno"], counts[current_count] + per_platform["current_count"] += 1 return result def replace(self, callcount): test_key = _current_test per_fn = self.data[test_key] per_platform = per_fn[self.platform_key] - counts = per_platform['counts'] - current_count = per_platform['current_count'] + counts = per_platform["counts"] + current_count = per_platform["current_count"] if current_count < len(counts): counts[current_count - 1] = callcount else: @@ -164,9 +162,9 @@ class ProfileStatsFile(object): per_fn = self.data[test_key] per_platform = per_fn[platform_key] c = [int(count) for count in counts.split(",")] - per_platform['counts'] = c - per_platform['lineno'] = lineno + 1 - per_platform['current_count'] = 0 + per_platform["counts"] = c + per_platform["lineno"] = lineno + 1 + per_platform["current_count"] = 0 profile_f.close() def _write(self): @@ -179,7 +177,7 @@ class ProfileStatsFile(object): profile_f.write("\n# TEST: %s\n\n" % test_key) for platform_key in sorted(per_fn): per_platform = per_fn[platform_key] - c = ",".join(str(count) for count in per_platform['counts']) + c = ",".join(str(count) for count in per_platform["counts"]) profile_f.write("%s %s %s\n" % (test_key, platform_key, c)) profile_f.close() @@ -199,7 +197,9 @@ def function_call_count(variance=0.05): def wrap(*args, **kw): with count_functions(variance=variance): return fn(*args, **kw) + return update_wrapper(wrap, fn) + return decorate @@ -213,21 +213,22 @@ def count_functions(variance=0.05): "No profiling stats available on this " "platform for this function. Run tests with " "--write-profiles to add statistics to %s for " - "this platform." % _profile_stats.short_fname) + "this platform." % _profile_stats.short_fname + ) gc_collect() pr = cProfile.Profile() pr.enable() - #began = time.time() + # began = time.time() yield - #ended = time.time() + # ended = time.time() pr.disable() - #s = compat.StringIO() + # s = compat.StringIO() stats = pstats.Stats(pr, stream=sys.stdout) - #timespent = ended - began + # timespent = ended - began callcount = stats.total_calls expected = _profile_stats.result(callcount) @@ -237,11 +238,7 @@ def count_functions(variance=0.05): else: line_no, expected_count = expected - print(("Pstats calls: %d Expected %s" % ( - callcount, - expected_count - ) - )) + print(("Pstats calls: %d Expected %s" % (callcount, expected_count))) stats.sort_stats("cumulative") stats.print_stats() @@ -259,7 +256,9 @@ def count_functions(variance=0.05): "--write-profiles to " "regenerate this callcount." % ( - callcount, (variance * 100), - expected_count, _profile_stats.platform_key)) - - + callcount, + (variance * 100), + expected_count, + _profile_stats.platform_key, + ) + ) diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index c0ca7c1cb..25028ccb3 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -8,6 +8,7 @@ import collections import os import time import logging + log = logging.getLogger(__name__) FOLLOWER_IDENT = None @@ -25,6 +26,7 @@ class register(object): def decorate(fn): self.fns[dbname] = fn return self + return decorate def __call__(self, cfg, *arg): @@ -38,7 +40,7 @@ class register(object): if backend in self.fns: return self.fns[backend](cfg, *arg) else: - return self.fns['*'](cfg, *arg) + return self.fns["*"](cfg, *arg) def create_follower_db(follower_ident): @@ -82,9 +84,7 @@ def _configs_for_db_operation(): for cfg in config.Config.all_configs(): url = cfg.db.url backend = url.get_backend_name() - host_conf = ( - backend, - url.username, url.host, url.database) + host_conf = (backend, url.username, url.host, url.database) if host_conf not in hosts: yield cfg @@ -128,14 +128,13 @@ def _follower_url_from_main(url, ident): @_update_db_opts.for_db("mssql") def _mssql_update_db_opts(db_url, db_opts): - db_opts['legacy_schema_aliasing'] = False - + db_opts["legacy_schema_aliasing"] = False @_follower_url_from_main.for_db("sqlite") def _sqlite_follower_url_from_main(url, ident): url = sa_url.make_url(url) - if not url.database or url.database == ':memory:': + if not url.database or url.database == ":memory:": return url else: return sa_url.make_url("sqlite:///%s.db" % ident) @@ -151,19 +150,20 @@ def _sqlite_post_configure_engine(url, engine, follower_ident): # as an attached if not follower_ident: dbapi_connection.execute( - 'ATTACH DATABASE "test_schema.db" AS test_schema') + 'ATTACH DATABASE "test_schema.db" AS test_schema' + ) else: dbapi_connection.execute( 'ATTACH DATABASE "%s_test_schema.db" AS test_schema' - % follower_ident) + % follower_ident + ) @_create_db.for_db("postgresql") def _pg_create_db(cfg, eng, ident): template_db = cfg.options.postgresql_templatedb - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: try: _pg_drop_db(cfg, conn, ident) except Exception: @@ -175,7 +175,8 @@ def _pg_create_db(cfg, eng, ident): while True: try: conn.execute( - "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)) + "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db) + ) except exc.OperationalError as err: attempt += 1 if attempt >= 3: @@ -184,8 +185,11 @@ def _pg_create_db(cfg, eng, ident): log.info( "Waiting to create %s, URI %r, " "template DB %s is in use sleeping for .5", - ident, eng.url, template_db) - time.sleep(.5) + ident, + eng.url, + template_db, + ) + time.sleep(0.5) else: break @@ -203,9 +207,11 @@ def _mysql_create_db(cfg, eng, ident): # 1271, u"Illegal mix of collations for operation 'UNION'" conn.execute("CREATE DATABASE %s CHARACTER SET utf8mb3" % ident) conn.execute( - "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident) + "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident + ) conn.execute( - "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident) + "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident + ) @_configure_follower.for_db("mysql") @@ -221,14 +227,15 @@ def _sqlite_create_db(cfg, eng, ident): @_drop_db.for_db("postgresql") def _pg_drop_db(cfg, eng, ident): - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: conn.execute( text( "select pg_terminate_backend(pid) from pg_stat_activity " "where usename=current_user and pid != pg_backend_pid() " "and datname=:dname" - ), dname=ident) + ), + dname=ident, + ) conn.execute("DROP DATABASE %s" % ident) @@ -257,11 +264,12 @@ def _oracle_create_db(cfg, eng, ident): conn.execute("create user %s identified by xe" % ident) conn.execute("create user %s_ts1 identified by xe" % ident) conn.execute("create user %s_ts2 identified by xe" % ident) - conn.execute("grant dba to %s" % (ident, )) + conn.execute("grant dba to %s" % (ident,)) conn.execute("grant unlimited tablespace to %s" % ident) conn.execute("grant unlimited tablespace to %s_ts1" % ident) conn.execute("grant unlimited tablespace to %s_ts2" % ident) + @_configure_follower.for_db("oracle") def _oracle_configure_follower(config, ident): config.test_schema = "%s_ts1" % ident @@ -320,6 +328,7 @@ def reap_dbs(idents_file): elif backend == "mssql": _reap_mssql_dbs(url, ident) + def _reap_oracle_dbs(url, idents): log.info("db reaper connecting to %r", url) eng = create_engine(url) @@ -330,8 +339,9 @@ def _reap_oracle_dbs(url, idents): to_reap = conn.execute( "select u.username from all_users u where username " "like 'TEST_%' and not exists (select username " - "from v$session where username=u.username)") - all_names = {username.lower() for (username, ) in to_reap} + "from v$session where username=u.username)" + ) + all_names = {username.lower() for (username,) in to_reap} to_drop = set() for name in all_names: if name.endswith("_ts1") or name.endswith("_ts2"): @@ -348,28 +358,28 @@ def _reap_oracle_dbs(url, idents): if _ora_drop_ignore(conn, username): dropped += 1 log.info( - "Dropped %d out of %d stale databases detected", - dropped, total) - + "Dropped %d out of %d stale databases detected", dropped, total + ) @_follower_url_from_main.for_db("oracle") def _oracle_follower_url_from_main(url, ident): url = sa_url.make_url(url) url.username = ident - url.password = 'xe' + url.password = "xe" return url @_create_db.for_db("mssql") def _mssql_create_db(cfg, eng, ident): - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: conn.execute("create database %s" % ident) conn.execute( - "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident) + "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident + ) conn.execute( - "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident) + "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident + ) conn.execute("use %s" % ident) conn.execute("create schema test_schema") conn.execute("create schema test_schema_2") @@ -377,10 +387,10 @@ def _mssql_create_db(cfg, eng, ident): @_drop_db.for_db("mssql") def _mssql_drop_db(cfg, eng, ident): - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: _mssql_drop_ignore(conn, ident) + def _mssql_drop_ignore(conn, ident): try: # typically when this happens, we can't KILL the session anyway, @@ -401,8 +411,7 @@ def _mssql_drop_ignore(conn, ident): def _reap_mssql_dbs(url, idents): log.info("db reaper connecting to %r", url) eng = create_engine(url) - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: log.info("identifiers in file: %s", ", ".join(idents)) @@ -410,8 +419,9 @@ def _reap_mssql_dbs(url, idents): "select d.name from sys.databases as d where name " "like 'TEST_%' and not exists (select session_id " "from sys.dm_exec_sessions " - "where database_id=d.database_id)") - all_names = {dbname.lower() for (dbname, ) in to_reap} + "where database_id=d.database_id)" + ) + all_names = {dbname.lower() for (dbname,) in to_reap} to_drop = set() for name in all_names: if name in idents: @@ -422,5 +432,5 @@ def _reap_mssql_dbs(url, idents): if _mssql_drop_ignore(conn, dbname): dropped += 1 log.info( - "Dropped %d out of %d stale databases detected", - dropped, total) + "Dropped %d out of %d stale databases detected", dropped, total + ) diff --git a/lib/sqlalchemy/testing/replay_fixture.py b/lib/sqlalchemy/testing/replay_fixture.py index b50f52e3d..9832b07a2 100644 --- a/lib/sqlalchemy/testing/replay_fixture.py +++ b/lib/sqlalchemy/testing/replay_fixture.py @@ -11,7 +11,6 @@ from sqlalchemy.orm import Session class ReplayFixtureTest(fixtures.TestBase): - @contextlib.contextmanager def _dummy_ctx(self, *arg, **kw): yield @@ -22,8 +21,8 @@ class ReplayFixtureTest(fixtures.TestBase): creator = config.db.pool._creator recorder = lambda: dbapi_session.recorder(creator()) engine = create_engine( - config.db.url, creator=recorder, - use_native_hstore=False) + config.db.url, creator=recorder, use_native_hstore=False + ) self.metadata = MetaData(engine) self.engine = engine self.session = Session(engine) @@ -37,8 +36,8 @@ class ReplayFixtureTest(fixtures.TestBase): player = lambda: dbapi_session.player() engine = create_engine( - config.db.url, creator=player, - use_native_hstore=False) + config.db.url, creator=player, use_native_hstore=False + ) self.metadata = MetaData(engine) self.engine = engine @@ -74,21 +73,49 @@ class ReplayableSession(object): NoAttribute = object() if util.py2k: - Natives = set([getattr(types, t) - for t in dir(types) if not t.startswith('_')]).\ - difference([getattr(types, t) - for t in ('FunctionType', 'BuiltinFunctionType', - 'MethodType', 'BuiltinMethodType', - 'LambdaType', 'UnboundMethodType',)]) + Natives = set( + [getattr(types, t) for t in dir(types) if not t.startswith("_")] + ).difference( + [ + getattr(types, t) + for t in ( + "FunctionType", + "BuiltinFunctionType", + "MethodType", + "BuiltinMethodType", + "LambdaType", + "UnboundMethodType", + ) + ] + ) else: - Natives = set([getattr(types, t) - for t in dir(types) if not t.startswith('_')]).\ - union([type(t) if not isinstance(t, type) - else t for t in __builtins__.values()]).\ - difference([getattr(types, t) - for t in ('FunctionType', 'BuiltinFunctionType', - 'MethodType', 'BuiltinMethodType', - 'LambdaType', )]) + Natives = ( + set( + [ + getattr(types, t) + for t in dir(types) + if not t.startswith("_") + ] + ) + .union( + [ + type(t) if not isinstance(t, type) else t + for t in __builtins__.values() + ] + ) + .difference( + [ + getattr(types, t) + for t in ( + "FunctionType", + "BuiltinFunctionType", + "MethodType", + "BuiltinMethodType", + "LambdaType", + ) + ] + ) + ) def __init__(self): self.buffer = deque() @@ -105,8 +132,10 @@ class ReplayableSession(object): self._subject = subject def __call__(self, *args, **kw): - subject, buffer = [object.__getattribute__(self, x) - for x in ('_subject', '_buffer')] + subject, buffer = [ + object.__getattribute__(self, x) + for x in ("_subject", "_buffer") + ] result = subject(*args, **kw) if type(result) not in ReplayableSession.Natives: @@ -126,8 +155,10 @@ class ReplayableSession(object): except AttributeError: pass - subject, buffer = [object.__getattribute__(self, x) - for x in ('_subject', '_buffer')] + subject, buffer = [ + object.__getattribute__(self, x) + for x in ("_subject", "_buffer") + ] try: result = type(subject).__getattribute__(subject, key) except AttributeError: @@ -146,7 +177,7 @@ class ReplayableSession(object): self._buffer = buffer def __call__(self, *args, **kw): - buffer = object.__getattribute__(self, '_buffer') + buffer = object.__getattribute__(self, "_buffer") result = buffer.popleft() if result is ReplayableSession.Callable: return self @@ -162,7 +193,7 @@ class ReplayableSession(object): return object.__getattribute__(self, key) except AttributeError: pass - buffer = object.__getattribute__(self, '_buffer') + buffer = object.__getattribute__(self, "_buffer") result = buffer.popleft() if result is ReplayableSession.Callable: return self diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 58df643f4..c96d26d32 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -26,7 +26,6 @@ class Requirements(object): class SuiteRequirements(Requirements): - @property def create_table(self): """target platform can emit basic CreateTable DDL.""" @@ -68,8 +67,8 @@ class SuiteRequirements(Requirements): # somehow only_if([x, y]) isn't working here, negation/conjunctions # getting confused. return exclusions.only_if( - lambda: self.on_update_cascade.enabled or - self.deferrable_fks.enabled + lambda: self.on_update_cascade.enabled + or self.deferrable_fks.enabled ) @property @@ -231,22 +230,21 @@ class SuiteRequirements(Requirements): def sane_rowcount(self): return exclusions.skip_if( lambda config: not config.db.dialect.supports_sane_rowcount, - "driver doesn't support 'sane' rowcount" + "driver doesn't support 'sane' rowcount", ) @property def sane_multi_rowcount(self): return exclusions.fails_if( lambda config: not config.db.dialect.supports_sane_multi_rowcount, - "driver %(driver)s %(doesnt_support)s 'sane' multi row count" + "driver %(driver)s %(doesnt_support)s 'sane' multi row count", ) @property def sane_rowcount_w_returning(self): return exclusions.fails_if( - lambda config: - not config.db.dialect.supports_sane_rowcount_returning, - "driver doesn't support 'sane' rowcount when returning is on" + lambda config: not config.db.dialect.supports_sane_rowcount_returning, + "driver doesn't support 'sane' rowcount when returning is on", ) @property @@ -255,9 +253,9 @@ class SuiteRequirements(Requirements): INSERT DEFAULT VALUES or equivalent.""" return exclusions.only_if( - lambda config: config.db.dialect.supports_empty_insert or - config.db.dialect.supports_default_values, - "empty inserts not supported" + lambda config: config.db.dialect.supports_empty_insert + or config.db.dialect.supports_default_values, + "empty inserts not supported", ) @property @@ -272,7 +270,7 @@ class SuiteRequirements(Requirements): return exclusions.only_if( lambda config: config.db.dialect.implicit_returning, - "%(database)s %(does_support)s 'returning'" + "%(database)s %(does_support)s 'returning'", ) @property @@ -297,7 +295,7 @@ class SuiteRequirements(Requirements): return exclusions.skip_if( lambda config: not config.db.dialect.requires_name_normalize, - "Backend does not require denormalized names." + "Backend does not require denormalized names.", ) @property @@ -307,7 +305,7 @@ class SuiteRequirements(Requirements): return exclusions.skip_if( lambda config: not config.db.dialect.supports_multivalues_insert, - "Backend does not support multirow inserts." + "Backend does not support multirow inserts.", ) @property @@ -355,27 +353,32 @@ class SuiteRequirements(Requirements): def server_side_cursors(self): """Target dialect must support server side cursors.""" - return exclusions.only_if([ - lambda config: config.db.dialect.supports_server_side_cursors - ], "no server side cursors support") + return exclusions.only_if( + [lambda config: config.db.dialect.supports_server_side_cursors], + "no server side cursors support", + ) @property def sequences(self): """Target database must support SEQUENCEs.""" - return exclusions.only_if([ - lambda config: config.db.dialect.supports_sequences - ], "no sequence support") + return exclusions.only_if( + [lambda config: config.db.dialect.supports_sequences], + "no sequence support", + ) @property def sequences_optional(self): """Target database supports sequences, but also optionally as a means of generating new PK values.""" - return exclusions.only_if([ - lambda config: config.db.dialect.supports_sequences and - config.db.dialect.sequences_optional - ], "no sequence support, or sequences not optional") + return exclusions.only_if( + [ + lambda config: config.db.dialect.supports_sequences + and config.db.dialect.sequences_optional + ], + "no sequence support, or sequences not optional", + ) @property def reflects_pk_names(self): @@ -841,7 +844,8 @@ class SuiteRequirements(Requirements): """ return exclusions.skip_if( - lambda config: config.options.low_connections) + lambda config: config.options.low_connections + ) @property def timing_intensive(self): @@ -859,37 +863,37 @@ class SuiteRequirements(Requirements): """ return exclusions.skip_if( lambda config: util.py3k and config.options.has_coverage, - "Stability issues with coverage + py3k" + "Stability issues with coverage + py3k", ) @property def python2(self): return exclusions.skip_if( lambda: sys.version_info >= (3,), - "Python version 2.xx is required." + "Python version 2.xx is required.", ) @property def python3(self): return exclusions.skip_if( - lambda: sys.version_info < (3,), - "Python version 3.xx is required." + lambda: sys.version_info < (3,), "Python version 3.xx is required." ) @property def cpython(self): return exclusions.only_if( - lambda: util.cpython, - "cPython interpreter needed" + lambda: util.cpython, "cPython interpreter needed" ) @property def non_broken_pickle(self): from sqlalchemy.util import pickle + return exclusions.only_if( - lambda: not util.pypy and pickle.__name__ == 'cPickle' - or sys.version_info >= (3, 2), - "Needs cPickle+cPython or newer Python 3 pickle" + lambda: not util.pypy + and pickle.__name__ == "cPickle" + or sys.version_info >= (3, 2), + "Needs cPickle+cPython or newer Python 3 pickle", ) @property @@ -910,7 +914,7 @@ class SuiteRequirements(Requirements): """ return exclusions.skip_if( lambda config: config.options.has_coverage, - "Issues observed when coverage is enabled" + "Issues observed when coverage is enabled", ) def _has_mysql_on_windows(self, config): @@ -931,8 +935,9 @@ class SuiteRequirements(Requirements): def _has_sqlite(self): from sqlalchemy import create_engine + try: - create_engine('sqlite://') + create_engine("sqlite://") return True except ImportError: return False @@ -940,6 +945,7 @@ class SuiteRequirements(Requirements): def _has_cextensions(self): try: from sqlalchemy import cresultproxy, cprocessors + return True except ImportError: return False diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py index 87be0749c..6aa820fd5 100644 --- a/lib/sqlalchemy/testing/runner.py +++ b/lib/sqlalchemy/testing/runner.py @@ -47,4 +47,4 @@ def setup_py_test(): to nose. """ - nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner']) + nose.main(addplugins=[NoseSQLAlchemy()], argv=["runner"]) diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 401c8cbb7..b345a9487 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -9,7 +9,7 @@ from . import exclusions from .. import schema, event from . import config -__all__ = 'Table', 'Column', +__all__ = "Table", "Column" table_options = {} @@ -17,30 +17,35 @@ table_options = {} def Table(*args, **kw): """A schema.Table wrapper/hook for dialect-specific tweaks.""" - test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')} + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} kw.update(table_options) - if exclusions.against(config._current, 'mysql'): - if 'mysql_engine' not in kw and 'mysql_type' not in kw and \ - "autoload_with" not in kw: - if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: - kw['mysql_engine'] = 'InnoDB' + if exclusions.against(config._current, "mysql"): + if ( + "mysql_engine" not in kw + and "mysql_type" not in kw + and "autoload_with" not in kw + ): + if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts: + kw["mysql_engine"] = "InnoDB" else: - kw['mysql_engine'] = 'MyISAM' + kw["mysql_engine"] = "MyISAM" # Apply some default cascading rules for self-referential foreign keys. # MySQL InnoDB has some issues around seleting self-refs too. - if exclusions.against(config._current, 'firebird'): + if exclusions.against(config._current, "firebird"): table_name = args[0] - unpack = (config.db.dialect. - identifier_preparer.unformat_identifiers) + unpack = config.db.dialect.identifier_preparer.unformat_identifiers # Only going after ForeignKeys in Columns. May need to # expand to ForeignKeyConstraint too. - fks = [fk - for col in args if isinstance(col, schema.Column) - for fk in col.foreign_keys] + fks = [ + fk + for col in args + if isinstance(col, schema.Column) + for fk in col.foreign_keys + ] for fk in fks: # root around in raw spec @@ -54,9 +59,9 @@ def Table(*args, **kw): name = unpack(ref)[0] if name == table_name: if fk.ondelete is None: - fk.ondelete = 'CASCADE' + fk.ondelete = "CASCADE" if fk.onupdate is None: - fk.onupdate = 'CASCADE' + fk.onupdate = "CASCADE" return schema.Table(*args, **kw) @@ -64,37 +69,46 @@ def Table(*args, **kw): def Column(*args, **kw): """A schema.Column wrapper/hook for dialect-specific tweaks.""" - test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')} + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} if not config.requirements.foreign_key_ddl.enabled_for_config(config): args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)] col = schema.Column(*args, **kw) - if test_opts.get('test_needs_autoincrement', False) and \ - kw.get('primary_key', False): + if test_opts.get("test_needs_autoincrement", False) and kw.get( + "primary_key", False + ): if col.default is None and col.server_default is None: col.autoincrement = True # allow any test suite to pick up on this - col.info['test_needs_autoincrement'] = True + col.info["test_needs_autoincrement"] = True # hardcoded rule for firebird, oracle; this should # be moved out - if exclusions.against(config._current, 'firebird', 'oracle'): + if exclusions.against(config._current, "firebird", "oracle"): + def add_seq(c, tbl): c._init_items( - schema.Sequence(_truncate_name( - config.db.dialect, tbl.name + '_' + c.name + '_seq'), - optional=True) + schema.Sequence( + _truncate_name( + config.db.dialect, tbl.name + "_" + c.name + "_seq" + ), + optional=True, + ) ) - event.listen(col, 'after_parent_attach', add_seq, propagate=True) + + event.listen(col, "after_parent_attach", add_seq, propagate=True) return col def _truncate_name(dialect, name): if len(name) > dialect.max_identifier_length: - return name[0:max(dialect.max_identifier_length - 6, 0)] + \ - "_" + hex(hash(name) % 64)[2:] + return ( + name[0 : max(dialect.max_identifier_length - 6, 0)] + + "_" + + hex(hash(name) % 64)[2:] + ) else: return name diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py index 748d9722d..a4e142c5a 100644 --- a/lib/sqlalchemy/testing/suite/__init__.py +++ b/lib/sqlalchemy/testing/suite/__init__.py @@ -1,4 +1,3 @@ - from sqlalchemy.testing.suite.test_cte import * from sqlalchemy.testing.suite.test_dialect import * from sqlalchemy.testing.suite.test_ddl import * diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py index cc72278e6..d2f35933b 100644 --- a/lib/sqlalchemy/testing/suite/test_cte.py +++ b/lib/sqlalchemy/testing/suite/test_cte.py @@ -10,22 +10,28 @@ from ..schema import Table, Column class CTETest(fixtures.TablesTest): __backend__ = True - __requires__ = 'ctes', + __requires__ = ("ctes",) - run_inserts = 'each' - run_deletes = 'each' + run_inserts = "each" + run_deletes = "each" @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column("parent_id", ForeignKey("some_table.id"))) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", ForeignKey("some_table.id")), + ) - Table("some_other_table", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column("parent_id", Integer)) + Table( + "some_other_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", Integer), + ) @classmethod def insert_data(cls): @@ -36,28 +42,33 @@ class CTETest(fixtures.TablesTest): {"id": 2, "data": "d2", "parent_id": 1}, {"id": 3, "data": "d3", "parent_id": 1}, {"id": 4, "data": "d4", "parent_id": 3}, - {"id": 5, "data": "d5", "parent_id": 3} - ] + {"id": 5, "data": "d5", "parent_id": 3}, + ], ) def test_select_nonrecursive_round_trip(self): some_table = self.tables.some_table with config.db.connect() as conn: - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) result = conn.execute( select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"])) ) - eq_(result.fetchall(), [("d4", )]) + eq_(result.fetchall(), [("d4",)]) def test_select_recursive_round_trip(self): some_table = self.tables.some_table with config.db.connect() as conn: - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"])).cte( - "some_cte", recursive=True) + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte", recursive=True) + ) cte_alias = cte.alias("c1") st1 = some_table.alias() @@ -67,12 +78,13 @@ class CTETest(fixtures.TablesTest): select([st1]).where(st1.c.id == cte_alias.c.parent_id) ) result = conn.execute( - select([cte.c.data]).where( - cte.c.data != "d2").order_by(cte.c.data.desc()) + select([cte.c.data]) + .where(cte.c.data != "d2") + .order_by(cte.c.data.desc()) ) eq_( result.fetchall(), - [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)] + [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)], ) def test_insert_from_select_round_trip(self): @@ -80,20 +92,21 @@ class CTETest(fixtures.TablesTest): some_other_table = self.tables.some_other_table with config.db.connect() as conn: - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( some_other_table.insert().from_select( - ["id", "data", "parent_id"], - select([cte]) + ["id", "data", "parent_id"], select([cte]) ) ) eq_( conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), - [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)] + [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)], ) @testing.requires.ctes_with_update_delete @@ -105,27 +118,31 @@ class CTETest(fixtures.TablesTest): with config.db.connect() as conn: conn.execute( some_other_table.insert().from_select( - ['id', 'data', 'parent_id'], - select([some_table]) + ["id", "data", "parent_id"], select([some_table]) ) ) - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( - some_other_table.update().values(parent_id=5).where( - some_other_table.c.data == cte.c.data - ) + some_other_table.update() + .values(parent_id=5) + .where(some_other_table.c.data == cte.c.data) ) eq_( conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), [ - (1, "d1", None), (2, "d2", 5), - (3, "d3", 5), (4, "d4", 5), (5, "d5", 3) - ] + (1, "d1", None), + (2, "d2", 5), + (3, "d3", 5), + (4, "d4", 5), + (5, "d5", 3), + ], ) @testing.requires.ctes_with_update_delete @@ -137,14 +154,15 @@ class CTETest(fixtures.TablesTest): with config.db.connect() as conn: conn.execute( some_other_table.insert().from_select( - ['id', 'data', 'parent_id'], - select([some_table]) + ["id", "data", "parent_id"], select([some_table]) ) ) - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( some_other_table.delete().where( some_other_table.c.data == cte.c.data @@ -154,9 +172,7 @@ class CTETest(fixtures.TablesTest): conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), - [ - (1, "d1", None), (5, "d5", 3) - ] + [(1, "d1", None), (5, "d5", 3)], ) @testing.requires.ctes_with_update_delete @@ -168,26 +184,26 @@ class CTETest(fixtures.TablesTest): with config.db.connect() as conn: conn.execute( some_other_table.insert().from_select( - ['id', 'data', 'parent_id'], - select([some_table]) + ["id", "data", "parent_id"], select([some_table]) ) ) - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( some_other_table.delete().where( - some_other_table.c.data == - select([cte.c.data]).where( - cte.c.id == some_other_table.c.id) + some_other_table.c.data + == select([cte.c.data]).where( + cte.c.id == some_other_table.c.id + ) ) ) eq_( conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), - [ - (1, "d1", None), (5, "d5", 3) - ] + [(1, "d1", None), (5, "d5", 3)], ) diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py index 1d8010c8a..7c44388d4 100644 --- a/lib/sqlalchemy/testing/suite/test_ddl.py +++ b/lib/sqlalchemy/testing/suite/test_ddl.py @@ -1,5 +1,3 @@ - - from .. import fixtures, config, util from ..config import requirements from ..assertions import eq_ @@ -11,55 +9,47 @@ class TableDDLTest(fixtures.TestBase): __backend__ = True def _simple_fixture(self): - return Table('test_table', self.metadata, - Column('id', Integer, primary_key=True, - autoincrement=False), - Column('data', String(50)) - ) + return Table( + "test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) def _underscore_fixture(self): - return Table('_test_table', self.metadata, - Column('id', Integer, primary_key=True, - autoincrement=False), - Column('_data', String(50)) - ) + return Table( + "_test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("_data", String(50)), + ) def _simple_roundtrip(self, table): with config.db.begin() as conn: - conn.execute(table.insert().values((1, 'some data'))) + conn.execute(table.insert().values((1, "some data"))) result = conn.execute(table.select()) - eq_( - result.first(), - (1, 'some data') - ) + eq_(result.first(), (1, "some data")) @requirements.create_table @util.provide_metadata def test_create_table(self): table = self._simple_fixture() - table.create( - config.db, checkfirst=False - ) + table.create(config.db, checkfirst=False) self._simple_roundtrip(table) @requirements.drop_table @util.provide_metadata def test_drop_table(self): table = self._simple_fixture() - table.create( - config.db, checkfirst=False - ) - table.drop( - config.db, checkfirst=False - ) + table.create(config.db, checkfirst=False) + table.drop(config.db, checkfirst=False) @requirements.create_table @util.provide_metadata def test_underscore_names(self): table = self._underscore_fixture() - table.create( - config.db, checkfirst=False - ) + table.create(config.db, checkfirst=False) self._simple_roundtrip(table) -__all__ = ('TableDDLTest', ) + +__all__ = ("TableDDLTest",) diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index 2c5dd0e36..5e589f3b8 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -15,16 +15,19 @@ class ExceptionTest(fixtures.TablesTest): specific exceptions from real round trips, we need to be conservative. """ - run_deletes = 'each' + + run_deletes = "each" __backend__ = True @classmethod def define_tables(cls, metadata): - Table('manual_pk', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) @requirements.duplicate_key_raises_integrity_error def test_integrity_error(self): @@ -33,15 +36,14 @@ class ExceptionTest(fixtures.TablesTest): trans = conn.begin() conn.execute( - self.tables.manual_pk.insert(), - {'id': 1, 'data': 'd1'} + self.tables.manual_pk.insert(), {"id": 1, "data": "d1"} ) assert_raises( exc.IntegrityError, conn.execute, self.tables.manual_pk.insert(), - {'id': 1, 'data': 'd1'} + {"id": 1, "data": "d1"}, ) trans.rollback() @@ -49,38 +51,39 @@ class ExceptionTest(fixtures.TablesTest): class AutocommitTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" - __requires__ = 'autocommit', + __requires__ = ("autocommit",) __backend__ = True @classmethod def define_tables(cls, metadata): - Table('some_table', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)), - test_needs_acid=True - ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + test_needs_acid=True, + ) def _test_conn_autocommits(self, conn, autocommit): trans = conn.begin() conn.execute( - self.tables.some_table.insert(), - {"id": 1, "data": "some data"} + self.tables.some_table.insert(), {"id": 1, "data": "some data"} ) trans.rollback() eq_( conn.scalar(select([self.tables.some_table.c.id])), - 1 if autocommit else None + 1 if autocommit else None, ) conn.execute(self.tables.some_table.delete()) def test_autocommit_on(self): conn = config.db.connect() - c2 = conn.execution_options(isolation_level='AUTOCOMMIT') + c2 = conn.execution_options(isolation_level="AUTOCOMMIT") self._test_conn_autocommits(c2, True) conn.invalidate() self._test_conn_autocommits(conn, False) @@ -98,7 +101,7 @@ class EscapingTest(fixtures.TestBase): """ m = self.metadata - t = Table('t', m, Column('data', String(50))) + t = Table("t", m, Column("data", String(50))) t.create(config.db) with config.db.begin() as conn: conn.execute(t.insert(), dict(data="some % value")) @@ -107,14 +110,17 @@ class EscapingTest(fixtures.TestBase): eq_( conn.scalar( select([t.c.data]).where( - t.c.data == literal_column("'some % value'")) + t.c.data == literal_column("'some % value'") + ) ), - "some % value" + "some % value", ) eq_( conn.scalar( select([t.c.data]).where( - t.c.data == literal_column("'some %% other value'")) - ), "some %% other value" + t.c.data == literal_column("'some %% other value'") + ) + ), + "some %% other value", ) diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index c0b6b18eb..6257451eb 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -10,53 +10,48 @@ from ..schema import Table, Column class LastrowidTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" __backend__ = True - __requires__ = 'implements_get_lastrowid', 'autoincrement_insert' + __requires__ = "implements_get_lastrowid", "autoincrement_insert" __engine_options__ = {"implicit_returning": False} @classmethod def define_tables(cls, metadata): - Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)) - ) + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) - Table('manual_pk', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) def _assert_round_trip(self, table, conn): row = conn.execute(table.select()).first() - eq_( - row, - (config.db.dialect.default_sequence_base, "some data") - ) + eq_(row, (config.db.dialect.default_sequence_base, "some data")) def test_autoincrement_on_insert(self): - config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" - ) + config.db.execute(self.tables.autoinc_pk.insert(), data="some data") self._assert_round_trip(self.tables.autoinc_pk, config.db) def test_last_inserted_id(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) - eq_( - r.inserted_primary_key, - [pk] - ) + eq_(r.inserted_primary_key, [pk]) # failed on pypy1.9 but seems to be OK on pypy 2.1 # @exclusions.fails_if(lambda: util.pypy, @@ -65,50 +60,57 @@ class LastrowidTest(fixtures.TablesTest): @requirements.dbapi_lastrowid def test_native_lastrowid_autoinc(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) lastrowid = r.lastrowid pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) - eq_( - lastrowid, pk - ) + eq_(lastrowid, pk) class InsertBehaviorTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" __backend__ = True @classmethod def define_tables(cls, metadata): - Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)) - ) - Table('manual_pk', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) - Table('includes_defaults', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)), - Column('x', Integer, default=5), - Column('y', Integer, - default=literal_column("2", type_=Integer) + literal(2))) + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + Table( + "includes_defaults", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("x", Integer, default=5), + Column( + "y", + Integer, + default=literal_column("2", type_=Integer) + literal(2), + ), + ) def test_autoclose_on_insert(self): if requirements.returning.enabled: engine = engines.testing_engine( - options={'implicit_returning': False}) + options={"implicit_returning": False} + ) else: engine = config.db - r = engine.execute( - self.tables.autoinc_pk.insert(), - data="some data" - ) + r = engine.execute(self.tables.autoinc_pk.insert(), data="some data") assert r._soft_closed assert not r.closed assert r.is_insert @@ -117,8 +119,7 @@ class InsertBehaviorTest(fixtures.TablesTest): @requirements.returning def test_autoclose_on_insert_implicit_returning(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) assert r._soft_closed assert not r.closed @@ -127,15 +128,14 @@ class InsertBehaviorTest(fixtures.TablesTest): @requirements.empty_inserts def test_empty_insert(self): - r = config.db.execute( - self.tables.autoinc_pk.insert(), - ) + r = config.db.execute(self.tables.autoinc_pk.insert()) assert r._soft_closed assert not r.closed r = config.db.execute( - self.tables.autoinc_pk.select(). - where(self.tables.autoinc_pk.c.id != None) + self.tables.autoinc_pk.select().where( + self.tables.autoinc_pk.c.id != None + ) ) assert len(r.fetchall()) @@ -150,15 +150,15 @@ class InsertBehaviorTest(fixtures.TablesTest): dict(id=1, data="data1"), dict(id=2, data="data2"), dict(id=3, data="data3"), - ] + ], ) result = config.db.execute( - dest_table.insert(). - from_select( + dest_table.insert().from_select( ("data",), - select([src_table.c.data]). - where(src_table.c.data.in_(["data2", "data3"])) + select([src_table.c.data]).where( + src_table.c.data.in_(["data2", "data3"]) + ), ) ) @@ -167,7 +167,7 @@ class InsertBehaviorTest(fixtures.TablesTest): result = config.db.execute( select([dest_table.c.data]).order_by(dest_table.c.data) ) - eq_(result.fetchall(), [("data2", ), ("data3", )]) + eq_(result.fetchall(), [("data2",), ("data3",)]) @requirements.insert_from_select def test_insert_from_select_autoinc_no_rows(self): @@ -175,11 +175,11 @@ class InsertBehaviorTest(fixtures.TablesTest): dest_table = self.tables.autoinc_pk result = config.db.execute( - dest_table.insert(). - from_select( + dest_table.insert().from_select( ("data",), - select([src_table.c.data]). - where(src_table.c.data.in_(["data2", "data3"])) + select([src_table.c.data]).where( + src_table.c.data.in_(["data2", "data3"]) + ), ) ) eq_(result.inserted_primary_key, [None]) @@ -199,23 +199,23 @@ class InsertBehaviorTest(fixtures.TablesTest): dict(id=1, data="data1"), dict(id=2, data="data2"), dict(id=3, data="data3"), - ] + ], ) config.db.execute( - table.insert(inline=True). - from_select(("id", "data",), - select([table.c.id + 5, table.c.data]). - where(table.c.data.in_(["data2", "data3"])) - ), + table.insert(inline=True).from_select( + ("id", "data"), + select([table.c.id + 5, table.c.data]).where( + table.c.data.in_(["data2", "data3"]) + ), + ) ) eq_( config.db.execute( select([table.c.data]).order_by(table.c.data) ).fetchall(), - [("data1", ), ("data2", ), ("data2", ), - ("data3", ), ("data3", )] + [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)], ) @requirements.insert_from_select @@ -227,56 +227,60 @@ class InsertBehaviorTest(fixtures.TablesTest): dict(id=1, data="data1"), dict(id=2, data="data2"), dict(id=3, data="data3"), - ] + ], ) config.db.execute( - table.insert(inline=True). - from_select(("id", "data",), - select([table.c.id + 5, table.c.data]). - where(table.c.data.in_(["data2", "data3"])) - ), + table.insert(inline=True).from_select( + ("id", "data"), + select([table.c.id + 5, table.c.data]).where( + table.c.data.in_(["data2", "data3"]) + ), + ) ) eq_( config.db.execute( select([table]).order_by(table.c.data, table.c.id) ).fetchall(), - [(1, 'data1', 5, 4), (2, 'data2', 5, 4), - (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)] + [ + (1, "data1", 5, 4), + (2, "data2", 5, 4), + (7, "data2", 5, 4), + (3, "data3", 5, 4), + (8, "data3", 5, 4), + ], ) class ReturningTest(fixtures.TablesTest): - run_create_tables = 'each' - __requires__ = 'returning', 'autoincrement_insert' + run_create_tables = "each" + __requires__ = "returning", "autoincrement_insert" __backend__ = True __engine_options__ = {"implicit_returning": True} def _assert_round_trip(self, table, conn): row = conn.execute(table.select()).first() - eq_( - row, - (config.db.dialect.default_sequence_base, "some data") - ) + eq_(row, (config.db.dialect.default_sequence_base, "some data")) @classmethod def define_tables(cls, metadata): - Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)) - ) + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) @requirements.fetch_rows_post_commit def test_explicit_returning_pk_autocommit(self): engine = config.db table = self.tables.autoinc_pk r = engine.execute( - table.insert().returning( - table.c.id), - data="some data" + table.insert().returning(table.c.id), data="some data" ) pk = r.first()[0] fetched_pk = config.db.scalar(select([table.c.id])) @@ -287,9 +291,7 @@ class ReturningTest(fixtures.TablesTest): table = self.tables.autoinc_pk with engine.begin() as conn: r = conn.execute( - table.insert().returning( - table.c.id), - data="some data" + table.insert().returning(table.c.id), data="some data" ) pk = r.first()[0] fetched_pk = config.db.scalar(select([table.c.id])) @@ -297,23 +299,16 @@ class ReturningTest(fixtures.TablesTest): def test_autoincrement_on_insert_implcit_returning(self): - config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" - ) + config.db.execute(self.tables.autoinc_pk.insert(), data="some data") self._assert_round_trip(self.tables.autoinc_pk, config.db) def test_last_inserted_id_implicit_returning(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) - eq_( - r.inserted_primary_key, - [pk] - ) + eq_(r.inserted_primary_key, [pk]) -__all__ = ('LastrowidTest', 'InsertBehaviorTest', 'ReturningTest') +__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest") diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 00a5aac01..bfed5f1ab 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -1,5 +1,3 @@ - - import sqlalchemy as sa from sqlalchemy import exc as sa_exc from sqlalchemy import types as sql_types @@ -26,10 +24,12 @@ class HasTableTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table('test_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) + Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) def test_has_table(self): with config.db.begin() as conn: @@ -46,8 +46,10 @@ class ComponentReflectionTest(fixtures.TablesTest): def setup_bind(cls): if config.requirements.independent_connections.enabled: from sqlalchemy import pool + return engines.testing_engine( - options=dict(poolclass=pool.StaticPool)) + options=dict(poolclass=pool.StaticPool) + ) else: return config.db @@ -65,86 +67,109 @@ class ComponentReflectionTest(fixtures.TablesTest): schema_prefix = "" if testing.requires.self_referential_foreign_keys.enabled: - users = Table('users', metadata, - Column('user_id', sa.INT, primary_key=True), - Column('test1', sa.CHAR(5), nullable=False), - Column('test2', sa.Float(5), nullable=False), - Column('parent_user_id', sa.Integer, - sa.ForeignKey('%susers.user_id' % - schema_prefix, - name='user_id_fk')), - schema=schema, - test_needs_fk=True, - ) + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + Column( + "parent_user_id", + sa.Integer, + sa.ForeignKey( + "%susers.user_id" % schema_prefix, name="user_id_fk" + ), + ), + schema=schema, + test_needs_fk=True, + ) else: - users = Table('users', metadata, - Column('user_id', sa.INT, primary_key=True), - Column('test1', sa.CHAR(5), nullable=False), - Column('test2', sa.Float(5), nullable=False), - schema=schema, - test_needs_fk=True, - ) - - Table("dingalings", metadata, - Column('dingaling_id', sa.Integer, primary_key=True), - Column('address_id', sa.Integer, - sa.ForeignKey('%semail_addresses.address_id' % - schema_prefix)), - Column('data', sa.String(30)), - schema=schema, - test_needs_fk=True, - ) - Table('email_addresses', metadata, - Column('address_id', sa.Integer), - Column('remote_user_id', sa.Integer, - sa.ForeignKey(users.c.user_id)), - Column('email_address', sa.String(20)), - sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'), - schema=schema, - test_needs_fk=True, - ) - Table('comment_test', metadata, - Column('id', sa.Integer, primary_key=True, comment='id comment'), - Column('data', sa.String(20), comment='data % comment'), - Column( - 'd2', sa.String(20), - comment=r"""Comment types type speedily ' " \ '' Fun!"""), - schema=schema, - comment=r"""the test % ' " \ table comment""") + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + schema=schema, + test_needs_fk=True, + ) + + Table( + "dingalings", + metadata, + Column("dingaling_id", sa.Integer, primary_key=True), + Column( + "address_id", + sa.Integer, + sa.ForeignKey("%semail_addresses.address_id" % schema_prefix), + ), + Column("data", sa.String(30)), + schema=schema, + test_needs_fk=True, + ) + Table( + "email_addresses", + metadata, + Column("address_id", sa.Integer), + Column( + "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id) + ), + Column("email_address", sa.String(20)), + sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"), + schema=schema, + test_needs_fk=True, + ) + Table( + "comment_test", + metadata, + Column("id", sa.Integer, primary_key=True, comment="id comment"), + Column("data", sa.String(20), comment="data % comment"), + Column( + "d2", + sa.String(20), + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + schema=schema, + comment=r"""the test % ' " \ table comment""", + ) if testing.requires.cross_schema_fk_reflection.enabled: if schema is None: Table( - 'local_table', metadata, - Column('id', sa.Integer, primary_key=True), - Column('data', sa.String(20)), + "local_table", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), Column( - 'remote_id', + "remote_id", ForeignKey( - '%s.remote_table_2.id' % - testing.config.test_schema) + "%s.remote_table_2.id" % testing.config.test_schema + ), ), test_needs_fk=True, - schema=config.db.dialect.default_schema_name + schema=config.db.dialect.default_schema_name, ) else: Table( - 'remote_table', metadata, - Column('id', sa.Integer, primary_key=True), + "remote_table", + metadata, + Column("id", sa.Integer, primary_key=True), Column( - 'local_id', + "local_id", ForeignKey( - '%s.local_table.id' % - config.db.dialect.default_schema_name) + "%s.local_table.id" + % config.db.dialect.default_schema_name + ), ), - Column('data', sa.String(20)), + Column("data", sa.String(20)), schema=schema, test_needs_fk=True, ) Table( - 'remote_table_2', metadata, - Column('id', sa.Integer, primary_key=True), - Column('data', sa.String(20)), + "remote_table_2", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), schema=schema, test_needs_fk=True, ) @@ -155,19 +180,21 @@ class ComponentReflectionTest(fixtures.TablesTest): if not schema: # test_needs_fk is at the moment to force MySQL InnoDB noncol_idx_test_nopk = Table( - 'noncol_idx_test_nopk', metadata, - Column('q', sa.String(5)), + "noncol_idx_test_nopk", + metadata, + Column("q", sa.String(5)), test_needs_fk=True, ) noncol_idx_test_pk = Table( - 'noncol_idx_test_pk', metadata, - Column('id', sa.Integer, primary_key=True), - Column('q', sa.String(5)), + "noncol_idx_test_pk", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("q", sa.String(5)), test_needs_fk=True, ) - Index('noncol_idx_nopk', noncol_idx_test_nopk.c.q.desc()) - Index('noncol_idx_pk', noncol_idx_test_pk.c.q.desc()) + Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) + Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) if testing.requires.view_column_reflection.enabled: cls.define_views(metadata, schema) @@ -180,34 +207,35 @@ class ComponentReflectionTest(fixtures.TablesTest): # temp table fixture if testing.against("oracle"): kw = { - 'prefixes': ["GLOBAL TEMPORARY"], - 'oracle_on_commit': 'PRESERVE ROWS' + "prefixes": ["GLOBAL TEMPORARY"], + "oracle_on_commit": "PRESERVE ROWS", } else: - kw = { - 'prefixes': ["TEMPORARY"], - } + kw = {"prefixes": ["TEMPORARY"]} user_tmp = Table( - "user_tmp", metadata, + "user_tmp", + metadata, Column("id", sa.INT, primary_key=True), - Column('name', sa.VARCHAR(50)), - Column('foo', sa.INT), - sa.UniqueConstraint('name', name='user_tmp_uq'), + Column("name", sa.VARCHAR(50)), + Column("foo", sa.INT), + sa.UniqueConstraint("name", name="user_tmp_uq"), sa.Index("user_tmp_ix", "foo"), **kw ) - if testing.requires.view_reflection.enabled and \ - testing.requires.temporary_views.enabled: - event.listen( - user_tmp, "after_create", - DDL("create temporary view user_tmp_v as " - "select * from user_tmp") - ) + if ( + testing.requires.view_reflection.enabled + and testing.requires.temporary_views.enabled + ): event.listen( - user_tmp, "before_drop", - DDL("drop view user_tmp_v") + user_tmp, + "after_create", + DDL( + "create temporary view user_tmp_v as " + "select * from user_tmp" + ), ) + event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) @classmethod def define_index(cls, metadata, users): @@ -216,23 +244,19 @@ class ComponentReflectionTest(fixtures.TablesTest): @classmethod def define_views(cls, metadata, schema): - for table_name in ('users', 'email_addresses'): + for table_name in ("users", "email_addresses"): fullname = table_name if schema: fullname = "%s.%s" % (schema, table_name) - view_name = fullname + '_v' + view_name = fullname + "_v" query = "CREATE VIEW %s AS SELECT * FROM %s" % ( - view_name, fullname) - - event.listen( - metadata, - "after_create", - DDL(query) + view_name, + fullname, ) + + event.listen(metadata, "after_create", DDL(query)) event.listen( - metadata, - "before_drop", - DDL("DROP VIEW %s" % view_name) + metadata, "before_drop", DDL("DROP VIEW %s" % view_name) ) @testing.requires.schema_reflection @@ -244,9 +268,9 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schema_reflection def test_dialect_initialize(self): engine = engines.testing_engine() - assert not hasattr(engine.dialect, 'default_schema_name') + assert not hasattr(engine.dialect, "default_schema_name") inspect(engine) - assert hasattr(engine.dialect, 'default_schema_name') + assert hasattr(engine.dialect, "default_schema_name") @testing.requires.schema_reflection def test_get_default_schema_name(self): @@ -254,40 +278,49 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(insp.default_schema_name, testing.db.dialect.default_schema_name) @testing.provide_metadata - def _test_get_table_names(self, schema=None, table_type='table', - order_by=None): + def _test_get_table_names( + self, schema=None, table_type="table", order_by=None + ): _ignore_tables = [ - 'comment_test', 'noncol_idx_test_pk', 'noncol_idx_test_nopk', - 'local_table', 'remote_table', 'remote_table_2' + "comment_test", + "noncol_idx_test_pk", + "noncol_idx_test_nopk", + "local_table", + "remote_table", + "remote_table_2", ] meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) insp = inspect(meta.bind) - if table_type == 'view': + if table_type == "view": table_names = insp.get_view_names(schema) table_names.sort() - answer = ['email_addresses_v', 'users_v'] + answer = ["email_addresses_v", "users_v"] eq_(sorted(table_names), answer) else: table_names = [ - t for t in insp.get_table_names( - schema, - order_by=order_by) if t not in _ignore_tables] + t + for t in insp.get_table_names(schema, order_by=order_by) + if t not in _ignore_tables + ] - if order_by == 'foreign_key': - answer = ['users', 'email_addresses', 'dingalings'] + if order_by == "foreign_key": + answer = ["users", "email_addresses", "dingalings"] eq_(table_names, answer) else: - answer = ['dingalings', 'email_addresses', 'users'] + answer = ["dingalings", "email_addresses", "users"] eq_(sorted(table_names), answer) @testing.requires.temp_table_names def test_get_temp_table_names(self): insp = inspect(self.bind) temp_table_names = insp.get_temp_table_names() - eq_(sorted(temp_table_names), ['user_tmp']) + eq_(sorted(temp_table_names), ["user_tmp"]) @testing.requires.view_reflection @testing.requires.temp_table_names @@ -295,7 +328,7 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_temp_view_names(self): insp = inspect(self.bind) temp_table_names = insp.get_temp_view_names() - eq_(sorted(temp_table_names), ['user_tmp_v']) + eq_(sorted(temp_table_names), ["user_tmp_v"]) @testing.requires.table_reflection def test_get_table_names(self): @@ -304,7 +337,7 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.table_reflection @testing.requires.foreign_key_constraint_reflection def test_get_table_names_fks(self): - self._test_get_table_names(order_by='foreign_key') + self._test_get_table_names(order_by="foreign_key") @testing.requires.comment_reflection def test_get_comments(self): @@ -320,26 +353,24 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_( insp.get_table_comment("comment_test", schema=schema), - {"text": r"""the test % ' " \ table comment"""} + {"text": r"""the test % ' " \ table comment"""}, ) - eq_( - insp.get_table_comment("users", schema=schema), - {"text": None} - ) + eq_(insp.get_table_comment("users", schema=schema), {"text": None}) eq_( [ - {"name": rec['name'], "comment": rec['comment']} - for rec in - insp.get_columns("comment_test", schema=schema) + {"name": rec["name"], "comment": rec["comment"]} + for rec in insp.get_columns("comment_test", schema=schema) ], [ - {'comment': 'id comment', 'name': 'id'}, - {'comment': 'data % comment', 'name': 'data'}, - {'comment': r"""Comment types type speedily ' " \ '' Fun!""", - 'name': 'd2'} - ] + {"comment": "id comment", "name": "id"}, + {"comment": "data % comment", "name": "data"}, + { + "comment": r"""Comment types type speedily ' " \ '' Fun!""", + "name": "d2", + }, + ], ) @testing.requires.table_reflection @@ -349,30 +380,33 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.view_column_reflection def test_get_view_names(self): - self._test_get_table_names(table_type='view') + self._test_get_table_names(table_type="view") @testing.requires.view_column_reflection @testing.requires.schemas def test_get_view_names_with_schema(self): self._test_get_table_names( - testing.config.test_schema, table_type='view') + testing.config.test_schema, table_type="view" + ) @testing.requires.table_reflection @testing.requires.view_column_reflection def test_get_tables_and_views(self): self._test_get_table_names() - self._test_get_table_names(table_type='view') + self._test_get_table_names(table_type="view") - def _test_get_columns(self, schema=None, table_type='table'): + def _test_get_columns(self, schema=None, table_type="table"): meta = MetaData(testing.db) - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings - table_names = ['users', 'email_addresses'] - if table_type == 'view': - table_names = ['users_v', 'email_addresses_v'] + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) + table_names = ["users", "email_addresses"] + if table_type == "view": + table_names = ["users_v", "email_addresses_v"] insp = inspect(meta.bind) - for table_name, table in zip(table_names, (users, - addresses)): + for table_name, table in zip(table_names, (users, addresses)): schema_name = schema cols = insp.get_columns(table_name, schema=schema_name) self.assert_(len(cols) > 0, len(cols)) @@ -380,36 +414,46 @@ class ComponentReflectionTest(fixtures.TablesTest): # should be in order for i, col in enumerate(table.columns): - eq_(col.name, cols[i]['name']) - ctype = cols[i]['type'].__class__ + eq_(col.name, cols[i]["name"]) + ctype = cols[i]["type"].__class__ ctype_def = col.type if isinstance(ctype_def, sa.types.TypeEngine): ctype_def = ctype_def.__class__ # Oracle returns Date for DateTime. - if testing.against('oracle') and ctype_def \ - in (sql_types.Date, sql_types.DateTime): + if testing.against("oracle") and ctype_def in ( + sql_types.Date, + sql_types.DateTime, + ): ctype_def = sql_types.Date # assert that the desired type and return type share # a base within one of the generic types. - self.assert_(len(set(ctype.__mro__). - intersection(ctype_def.__mro__). - intersection([ - sql_types.Integer, - sql_types.Numeric, - sql_types.DateTime, - sql_types.Date, - sql_types.Time, - sql_types.String, - sql_types._Binary, - ])) > 0, '%s(%s), %s(%s)' % - (col.name, col.type, cols[i]['name'], ctype)) + self.assert_( + len( + set(ctype.__mro__) + .intersection(ctype_def.__mro__) + .intersection( + [ + sql_types.Integer, + sql_types.Numeric, + sql_types.DateTime, + sql_types.Date, + sql_types.Time, + sql_types.String, + sql_types._Binary, + ] + ) + ) + > 0, + "%s(%s), %s(%s)" + % (col.name, col.type, cols[i]["name"], ctype), + ) if not col.primary_key: - assert cols[i]['default'] is None + assert cols[i]["default"] is None @testing.requires.table_reflection def test_get_columns(self): @@ -417,24 +461,20 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.provide_metadata def _type_round_trip(self, *types): - t = Table('t', self.metadata, - *[ - Column('t%d' % i, type_) - for i, type_ in enumerate(types) - ] - ) + t = Table( + "t", + self.metadata, + *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] + ) t.create() return [ - c['type'] for c in - inspect(self.metadata.bind).get_columns('t') + c["type"] for c in inspect(self.metadata.bind).get_columns("t") ] @testing.requires.table_reflection def test_numeric_reflection(self): - for typ in self._type_round_trip( - sql_types.Numeric(18, 5), - ): + for typ in self._type_round_trip(sql_types.Numeric(18, 5)): assert isinstance(typ, sql_types.Numeric) eq_(typ.precision, 18) eq_(typ.scale, 5) @@ -448,16 +488,19 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.table_reflection @testing.provide_metadata def test_nullable_reflection(self): - t = Table('t', self.metadata, - Column('a', Integer, nullable=True), - Column('b', Integer, nullable=False)) + t = Table( + "t", + self.metadata, + Column("a", Integer, nullable=True), + Column("b", Integer, nullable=False), + ) t.create() eq_( dict( - (col['name'], col['nullable']) - for col in inspect(self.metadata.bind).get_columns('t') + (col["name"], col["nullable"]) + for col in inspect(self.metadata.bind).get_columns("t") ), - {"a": True, "b": False} + {"a": True, "b": False}, ) @testing.requires.table_reflection @@ -470,32 +513,30 @@ class ComponentReflectionTest(fixtures.TablesTest): meta = MetaData(self.bind) user_tmp = self.tables.user_tmp insp = inspect(meta.bind) - cols = insp.get_columns('user_tmp') + cols = insp.get_columns("user_tmp") self.assert_(len(cols) > 0, len(cols)) for i, col in enumerate(user_tmp.columns): - eq_(col.name, cols[i]['name']) + eq_(col.name, cols[i]["name"]) @testing.requires.temp_table_reflection @testing.requires.view_column_reflection @testing.requires.temporary_views def test_get_temp_view_columns(self): insp = inspect(self.bind) - cols = insp.get_columns('user_tmp_v') - eq_( - [col['name'] for col in cols], - ['id', 'name', 'foo'] - ) + cols = insp.get_columns("user_tmp_v") + eq_([col["name"] for col in cols], ["id", "name", "foo"]) @testing.requires.view_column_reflection def test_get_view_columns(self): - self._test_get_columns(table_type='view') + self._test_get_columns(table_type="view") @testing.requires.view_column_reflection @testing.requires.schemas def test_get_view_columns_with_schema(self): self._test_get_columns( - schema=testing.config.test_schema, table_type='view') + schema=testing.config.test_schema, table_type="view" + ) @testing.provide_metadata def _test_get_pk_constraint(self, schema=None): @@ -504,15 +545,15 @@ class ComponentReflectionTest(fixtures.TablesTest): insp = inspect(meta.bind) users_cons = insp.get_pk_constraint(users.name, schema=schema) - users_pkeys = users_cons['constrained_columns'] - eq_(users_pkeys, ['user_id']) + users_pkeys = users_cons["constrained_columns"] + eq_(users_pkeys, ["user_id"]) addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) - addr_pkeys = addr_cons['constrained_columns'] - eq_(addr_pkeys, ['address_id']) + addr_pkeys = addr_cons["constrained_columns"] + eq_(addr_pkeys, ["address_id"]) with testing.requires.reflects_pk_names.fail_if(): - eq_(addr_cons['name'], 'email_ad_pk') + eq_(addr_cons["name"], "email_ad_pk") @testing.requires.primary_key_constraint_reflection def test_get_pk_constraint(self): @@ -534,44 +575,46 @@ class ComponentReflectionTest(fixtures.TablesTest): sa_exc.SADeprecationWarning, "Call to deprecated method get_primary_keys." " Use get_pk_constraint instead.", - insp.get_primary_keys, users.name + insp.get_primary_keys, + users.name, ) @testing.provide_metadata def _test_get_foreign_keys(self, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) insp = inspect(meta.bind) expected_schema = schema # users if testing.requires.self_referential_foreign_keys.enabled: - users_fkeys = insp.get_foreign_keys(users.name, - schema=schema) + users_fkeys = insp.get_foreign_keys(users.name, schema=schema) fkey1 = users_fkeys[0] with testing.requires.named_constraints.fail_if(): - eq_(fkey1['name'], "user_id_fk") + eq_(fkey1["name"], "user_id_fk") - eq_(fkey1['referred_schema'], expected_schema) - eq_(fkey1['referred_table'], users.name) - eq_(fkey1['referred_columns'], ['user_id', ]) + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) if testing.requires.self_referential_foreign_keys.enabled: - eq_(fkey1['constrained_columns'], ['parent_user_id']) + eq_(fkey1["constrained_columns"], ["parent_user_id"]) # addresses - addr_fkeys = insp.get_foreign_keys(addresses.name, - schema=schema) + addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) fkey1 = addr_fkeys[0] with testing.requires.implicitly_named_constraints.fail_if(): - self.assert_(fkey1['name'] is not None) + self.assert_(fkey1["name"] is not None) - eq_(fkey1['referred_schema'], expected_schema) - eq_(fkey1['referred_table'], users.name) - eq_(fkey1['referred_columns'], ['user_id', ]) - eq_(fkey1['constrained_columns'], ['remote_user_id']) + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + eq_(fkey1["constrained_columns"], ["remote_user_id"]) @testing.requires.foreign_key_constraint_reflection def test_get_foreign_keys(self): @@ -586,9 +629,9 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schemas def test_get_inter_schema_foreign_keys(self): local_table, remote_table, remote_table_2 = self.tables( - '%s.local_table' % testing.db.dialect.default_schema_name, - '%s.remote_table' % testing.config.test_schema, - '%s.remote_table_2' % testing.config.test_schema + "%s.local_table" % testing.db.dialect.default_schema_name, + "%s.remote_table" % testing.config.test_schema, + "%s.remote_table_2" % testing.config.test_schema, ) insp = inspect(config.db) @@ -597,25 +640,25 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(len(local_fkeys), 1) fkey1 = local_fkeys[0] - eq_(fkey1['referred_schema'], testing.config.test_schema) - eq_(fkey1['referred_table'], remote_table_2.name) - eq_(fkey1['referred_columns'], ['id', ]) - eq_(fkey1['constrained_columns'], ['remote_id']) + eq_(fkey1["referred_schema"], testing.config.test_schema) + eq_(fkey1["referred_table"], remote_table_2.name) + eq_(fkey1["referred_columns"], ["id"]) + eq_(fkey1["constrained_columns"], ["remote_id"]) remote_fkeys = insp.get_foreign_keys( - remote_table.name, schema=testing.config.test_schema) + remote_table.name, schema=testing.config.test_schema + ) eq_(len(remote_fkeys), 1) fkey2 = remote_fkeys[0] - assert fkey2['referred_schema'] in ( + assert fkey2["referred_schema"] in ( None, - testing.db.dialect.default_schema_name + testing.db.dialect.default_schema_name, ) - eq_(fkey2['referred_table'], local_table.name) - eq_(fkey2['referred_columns'], ['id', ]) - eq_(fkey2['constrained_columns'], ['local_id']) - + eq_(fkey2["referred_table"], local_table.name) + eq_(fkey2["referred_columns"], ["id"]) + eq_(fkey2["constrained_columns"], ["local_id"]) @testing.requires.foreign_key_constraint_option_reflection_ondelete def test_get_foreign_key_options_ondelete(self): @@ -630,26 +673,32 @@ class ComponentReflectionTest(fixtures.TablesTest): meta = self.metadata Table( - 'x', meta, - Column('id', Integer, primary_key=True), - test_needs_fk=True - ) - - Table('table', meta, - Column('id', Integer, primary_key=True), - Column('x_id', Integer, sa.ForeignKey('x.id', name='xid')), - Column('test', String(10)), - test_needs_fk=True) - - Table('user', meta, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('tid', Integer), - sa.ForeignKeyConstraint( - ['tid'], ['table.id'], - name='myfk', - **options), - test_needs_fk=True) + "x", + meta, + Column("id", Integer, primary_key=True), + test_needs_fk=True, + ) + + Table( + "table", + meta, + Column("id", Integer, primary_key=True), + Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")), + Column("test", String(10)), + test_needs_fk=True, + ) + + Table( + "user", + meta, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + sa.ForeignKeyConstraint( + ["tid"], ["table.id"], name="myfk", **options + ), + test_needs_fk=True, + ) meta.create_all() @@ -657,49 +706,44 @@ class ComponentReflectionTest(fixtures.TablesTest): # test 'options' is always present for a backend # that can reflect these, since alembic looks for this - opts = insp.get_foreign_keys('table')[0]['options'] + opts = insp.get_foreign_keys("table")[0]["options"] - eq_( - dict( - (k, opts[k]) - for k in opts if opts[k] - ), - {} - ) + eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) - opts = insp.get_foreign_keys('user')[0]['options'] - eq_( - dict( - (k, opts[k]) - for k in opts if opts[k] - ), - options - ) + opts = insp.get_foreign_keys("user")[0]["options"] + eq_(dict((k, opts[k]) for k in opts if opts[k]), options) def _assert_insp_indexes(self, indexes, expected_indexes): - index_names = [d['name'] for d in indexes] + index_names = [d["name"] for d in indexes] for e_index in expected_indexes: - assert e_index['name'] in index_names - index = indexes[index_names.index(e_index['name'])] + assert e_index["name"] in index_names + index = indexes[index_names.index(e_index["name"])] for key in e_index: eq_(e_index[key], index[key]) @testing.provide_metadata def _test_get_indexes(self, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) # The database may decide to create indexes for foreign keys, etc. # so there may be more indexes than expected. insp = inspect(meta.bind) - indexes = insp.get_indexes('users', schema=schema) + indexes = insp.get_indexes("users", schema=schema) expected_indexes = [ - {'unique': False, - 'column_names': ['test1', 'test2'], - 'name': 'users_t_idx'}, - {'unique': False, - 'column_names': ['user_id', 'test2', 'test1'], - 'name': 'users_all_idx'} + { + "unique": False, + "column_names": ["test1", "test2"], + "name": "users_t_idx", + }, + { + "unique": False, + "column_names": ["user_id", "test2", "test1"], + "name": "users_all_idx", + }, ] self._assert_insp_indexes(indexes, expected_indexes) @@ -721,10 +765,7 @@ class ComponentReflectionTest(fixtures.TablesTest): # reflecting an index that has "x DESC" in it as the column. # the DB may or may not give us "x", but make sure we get the index # back, it has a name, it's connected to the table. - expected_indexes = [ - {'unique': False, - 'name': ixname} - ] + expected_indexes = [{"unique": False, "name": ixname}] self._assert_insp_indexes(indexes, expected_indexes) t = Table(tname, meta, autoload_with=meta.bind) @@ -748,24 +789,30 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.unique_constraint_reflection def test_get_temp_table_unique_constraints(self): insp = inspect(self.bind) - reflected = insp.get_unique_constraints('user_tmp') + reflected = insp.get_unique_constraints("user_tmp") for refl in reflected: # Different dialects handle duplicate index and constraints # differently, so ignore this flag - refl.pop('duplicates_index', None) - eq_(reflected, [{'column_names': ['name'], 'name': 'user_tmp_uq'}]) + refl.pop("duplicates_index", None) + eq_(reflected, [{"column_names": ["name"], "name": "user_tmp_uq"}]) @testing.requires.temp_table_reflection def test_get_temp_table_indexes(self): insp = inspect(self.bind) - indexes = insp.get_indexes('user_tmp') + indexes = insp.get_indexes("user_tmp") for ind in indexes: - ind.pop('dialect_options', None) + ind.pop("dialect_options", None) eq_( # TODO: we need to add better filtering for indexes/uq constraints # that are doubled up - [idx for idx in indexes if idx['name'] == 'user_tmp_ix'], - [{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}] + [idx for idx in indexes if idx["name"] == "user_tmp_ix"], + [ + { + "unique": False, + "column_names": ["foo"], + "name": "user_tmp_ix", + } + ], ) @testing.requires.unique_constraint_reflection @@ -783,36 +830,37 @@ class ComponentReflectionTest(fixtures.TablesTest): # CREATE TABLE? uniques = sorted( [ - {'name': 'unique_a', 'column_names': ['a']}, - {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']}, - {'name': 'unique_c_a_b', 'column_names': ['c', 'a', 'b']}, - {'name': 'unique_asc_key', 'column_names': ['asc', 'key']}, - {'name': 'i.have.dots', 'column_names': ['b']}, - {'name': 'i have spaces', 'column_names': ['c']}, + {"name": "unique_a", "column_names": ["a"]}, + {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]}, + {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]}, + {"name": "unique_asc_key", "column_names": ["asc", "key"]}, + {"name": "i.have.dots", "column_names": ["b"]}, + {"name": "i have spaces", "column_names": ["c"]}, ], - key=operator.itemgetter('name') + key=operator.itemgetter("name"), ) orig_meta = self.metadata table = Table( - 'testtbl', orig_meta, - Column('a', sa.String(20)), - Column('b', sa.String(30)), - Column('c', sa.Integer), + "testtbl", + orig_meta, + Column("a", sa.String(20)), + Column("b", sa.String(30)), + Column("c", sa.Integer), # reserved identifiers - Column('asc', sa.String(30)), - Column('key', sa.String(30)), - schema=schema + Column("asc", sa.String(30)), + Column("key", sa.String(30)), + schema=schema, ) for uc in uniques: table.append_constraint( - sa.UniqueConstraint(*uc['column_names'], name=uc['name']) + sa.UniqueConstraint(*uc["column_names"], name=uc["name"]) ) orig_meta.create_all() inspector = inspect(orig_meta.bind) reflected = sorted( - inspector.get_unique_constraints('testtbl', schema=schema), - key=operator.itemgetter('name') + inspector.get_unique_constraints("testtbl", schema=schema), + key=operator.itemgetter("name"), ) names_that_duplicate_index = set() @@ -820,25 +868,31 @@ class ComponentReflectionTest(fixtures.TablesTest): for orig, refl in zip(uniques, reflected): # Different dialects handle duplicate index and constraints # differently, so ignore this flag - dupe = refl.pop('duplicates_index', None) + dupe = refl.pop("duplicates_index", None) if dupe: names_that_duplicate_index.add(dupe) eq_(orig, refl) reflected_metadata = MetaData() reflected = Table( - 'testtbl', reflected_metadata, autoload_with=orig_meta.bind, - schema=schema) + "testtbl", + reflected_metadata, + autoload_with=orig_meta.bind, + schema=schema, + ) # test "deduplicates for index" logic. MySQL and Oracle # "unique constraints" are actually unique indexes (with possible # exception of a unique that is a dupe of another one in the case # of Oracle). make sure # they aren't duplicated. idx_names = set([idx.name for idx in reflected.indexes]) - uq_names = set([ - uq.name for uq in reflected.constraints - if isinstance(uq, sa.UniqueConstraint)]).difference( - ['unique_c_a_b']) + uq_names = set( + [ + uq.name + for uq in reflected.constraints + if isinstance(uq, sa.UniqueConstraint) + ] + ).difference(["unique_c_a_b"]) assert not idx_names.intersection(uq_names) if names_that_duplicate_index: @@ -858,47 +912,52 @@ class ComponentReflectionTest(fixtures.TablesTest): def _test_get_check_constraints(self, schema=None): orig_meta = self.metadata Table( - 'sa_cc', orig_meta, - Column('a', Integer()), - sa.CheckConstraint('a > 1 AND a < 5', name='cc1'), - sa.CheckConstraint('a = 1 OR (a > 2 AND a < 5)', name='cc2'), - schema=schema + "sa_cc", + orig_meta, + Column("a", Integer()), + sa.CheckConstraint("a > 1 AND a < 5", name="cc1"), + sa.CheckConstraint("a = 1 OR (a > 2 AND a < 5)", name="cc2"), + schema=schema, ) orig_meta.create_all() inspector = inspect(orig_meta.bind) reflected = sorted( - inspector.get_check_constraints('sa_cc', schema=schema), - key=operator.itemgetter('name') + inspector.get_check_constraints("sa_cc", schema=schema), + key=operator.itemgetter("name"), ) # trying to minimize effect of quoting, parenthesis, etc. # may need to add more to this as new dialects get CHECK # constraint reflection support def normalize(sqltext): - return " ".join(re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I)) + return " ".join( + re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I) + ) reflected = [ - {"name": item["name"], - "sqltext": normalize(item["sqltext"])} + {"name": item["name"], "sqltext": normalize(item["sqltext"])} for item in reflected ] eq_( reflected, [ - {'name': 'cc1', 'sqltext': 'a > 1 and a < 5'}, - {'name': 'cc2', 'sqltext': 'a = 1 or a > 2 and a < 5'} - ] + {"name": "cc1", "sqltext": "a > 1 and a < 5"}, + {"name": "cc2", "sqltext": "a = 1 or a > 2 and a < 5"}, + ], ) @testing.provide_metadata def _test_get_view_definition(self, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings - view_name1 = 'users_v' - view_name2 = 'email_addresses_v' + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) + view_name1 = "users_v" + view_name2 = "email_addresses_v" insp = inspect(meta.bind) v1 = insp.get_view_definition(view_name1, schema=schema) self.assert_(v1) @@ -918,18 +977,21 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.provide_metadata def _test_get_table_oid(self, table_name, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) insp = inspect(meta.bind) oid = insp.get_table_oid(table_name, schema) self.assert_(isinstance(oid, int)) def test_get_table_oid(self): - self._test_get_table_oid('users') + self._test_get_table_oid("users") @testing.requires.schemas def test_get_table_oid_with_schema(self): - self._test_get_table_oid('users', schema=testing.config.test_schema) + self._test_get_table_oid("users", schema=testing.config.test_schema) @testing.requires.table_reflection @testing.provide_metadata @@ -950,49 +1012,53 @@ class ComponentReflectionTest(fixtures.TablesTest): insp = inspect(meta.bind) for tname, cname in [ - ('users', 'user_id'), - ('email_addresses', 'address_id'), - ('dingalings', 'dingaling_id'), + ("users", "user_id"), + ("email_addresses", "address_id"), + ("dingalings", "dingaling_id"), ]: cols = insp.get_columns(tname) - id_ = {c['name']: c for c in cols}[cname] - assert id_.get('autoincrement', True) + id_ = {c["name"]: c for c in cols}[cname] + assert id_.get("autoincrement", True) class NormalizedNameTest(fixtures.TablesTest): - __requires__ = 'denormalized_names', + __requires__ = ("denormalized_names",) __backend__ = True @classmethod def define_tables(cls, metadata): Table( - quoted_name('t1', quote=True), metadata, - Column('id', Integer, primary_key=True), + quoted_name("t1", quote=True), + metadata, + Column("id", Integer, primary_key=True), ) Table( - quoted_name('t2', quote=True), metadata, - Column('id', Integer, primary_key=True), - Column('t1id', ForeignKey('t1.id')) + quoted_name("t2", quote=True), + metadata, + Column("id", Integer, primary_key=True), + Column("t1id", ForeignKey("t1.id")), ) def test_reflect_lowercase_forced_tables(self): m2 = MetaData(testing.db) - t2_ref = Table(quoted_name('t2', quote=True), m2, autoload=True) - t1_ref = m2.tables['t1'] + t2_ref = Table(quoted_name("t2", quote=True), m2, autoload=True) + t1_ref = m2.tables["t1"] assert t2_ref.c.t1id.references(t1_ref.c.id) m3 = MetaData(testing.db) - m3.reflect(only=lambda name, m: name.lower() in ('t1', 't2')) - assert m3.tables['t2'].c.t1id.references(m3.tables['t1'].c.id) + m3.reflect(only=lambda name, m: name.lower() in ("t1", "t2")) + assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id) def test_get_table_names(self): tablenames = [ - t for t in inspect(testing.db).get_table_names() - if t.lower() in ("t1", "t2")] + t + for t in inspect(testing.db).get_table_names() + if t.lower() in ("t1", "t2") + ] eq_(tablenames[0].upper(), tablenames[0].lower()) eq_(tablenames[1].upper(), tablenames[1].lower()) -__all__ = ('ComponentReflectionTest', 'HasTableTest', 'NormalizedNameTest') +__all__ = ("ComponentReflectionTest", "HasTableTest", "NormalizedNameTest") diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index f464d47eb..247f05cf5 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -15,14 +15,18 @@ class RowFetchTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table('plain_pk', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) - Table('has_dates', metadata, - Column('id', Integer, primary_key=True), - Column('today', DateTime) - ) + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + Table( + "has_dates", + metadata, + Column("id", Integer, primary_key=True), + Column("today", DateTime), + ) @classmethod def insert_data(cls): @@ -32,65 +36,51 @@ class RowFetchTest(fixtures.TablesTest): {"id": 1, "data": "d1"}, {"id": 2, "data": "d2"}, {"id": 3, "data": "d3"}, - ] + ], ) config.db.execute( cls.tables.has_dates.insert(), - [ - {"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)} - ] + [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}], ) def test_via_string(self): row = config.db.execute( - self.tables.plain_pk.select(). - order_by(self.tables.plain_pk.c.id) + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) ).first() - eq_( - row['id'], 1 - ) - eq_( - row['data'], "d1" - ) + eq_(row["id"], 1) + eq_(row["data"], "d1") def test_via_int(self): row = config.db.execute( - self.tables.plain_pk.select(). - order_by(self.tables.plain_pk.c.id) + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) ).first() - eq_( - row[0], 1 - ) - eq_( - row[1], "d1" - ) + eq_(row[0], 1) + eq_(row[1], "d1") def test_via_col_object(self): row = config.db.execute( - self.tables.plain_pk.select(). - order_by(self.tables.plain_pk.c.id) + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) ).first() - eq_( - row[self.tables.plain_pk.c.id], 1 - ) - eq_( - row[self.tables.plain_pk.c.data], "d1" - ) + eq_(row[self.tables.plain_pk.c.id], 1) + eq_(row[self.tables.plain_pk.c.data], "d1") @requirements.duplicate_names_in_cursor_description def test_row_with_dupe_names(self): result = config.db.execute( - select([self.tables.plain_pk.c.data, - self.tables.plain_pk.c.data.label('data')]). - order_by(self.tables.plain_pk.c.id) + select( + [ + self.tables.plain_pk.c.data, + self.tables.plain_pk.c.data.label("data"), + ] + ).order_by(self.tables.plain_pk.c.id) ) row = result.first() - eq_(result.keys(), ['data', 'data']) - eq_(row, ('d1', 'd1')) + eq_(result.keys(), ["data", "data"]) + eq_(row, ("d1", "d1")) def test_row_w_scalar_select(self): """test that a scalar select as a column is returned as such @@ -101,11 +91,11 @@ class RowFetchTest(fixtures.TablesTest): """ datetable = self.tables.has_dates - s = select([datetable.alias('x').c.today]).as_scalar() - s2 = select([datetable.c.id, s.label('somelabel')]) + s = select([datetable.alias("x").c.today]).as_scalar() + s2 = select([datetable.c.id, s.label("somelabel")]) row = config.db.execute(s2).first() - eq_(row['somelabel'], datetime.datetime(2006, 5, 12, 12, 0, 0)) + eq_(row["somelabel"], datetime.datetime(2006, 5, 12, 12, 0, 0)) class PercentSchemaNamesTest(fixtures.TablesTest): @@ -117,29 +107,31 @@ class PercentSchemaNamesTest(fixtures.TablesTest): """ - __requires__ = ('percent_schema_names', ) + __requires__ = ("percent_schema_names",) __backend__ = True @classmethod def define_tables(cls, metadata): - cls.tables.percent_table = Table('percent%table', metadata, - Column("percent%", Integer), - Column( - "spaces % more spaces", Integer), - ) + cls.tables.percent_table = Table( + "percent%table", + metadata, + Column("percent%", Integer), + Column("spaces % more spaces", Integer), + ) cls.tables.lightweight_percent_table = sql.table( - 'percent%table', sql.column("percent%"), - sql.column("spaces % more spaces") + "percent%table", + sql.column("percent%"), + sql.column("spaces % more spaces"), ) def test_single_roundtrip(self): percent_table = self.tables.percent_table for params in [ - {'percent%': 5, 'spaces % more spaces': 12}, - {'percent%': 7, 'spaces % more spaces': 11}, - {'percent%': 9, 'spaces % more spaces': 10}, - {'percent%': 11, 'spaces % more spaces': 9} + {"percent%": 5, "spaces % more spaces": 12}, + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, ]: config.db.execute(percent_table.insert(), params) self._assert_table() @@ -147,14 +139,15 @@ class PercentSchemaNamesTest(fixtures.TablesTest): def test_executemany_roundtrip(self): percent_table = self.tables.percent_table config.db.execute( - percent_table.insert(), - {'percent%': 5, 'spaces % more spaces': 12} + percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12} ) config.db.execute( percent_table.insert(), - [{'percent%': 7, 'spaces % more spaces': 11}, - {'percent%': 9, 'spaces % more spaces': 10}, - {'percent%': 11, 'spaces % more spaces': 9}] + [ + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ], ) self._assert_table() @@ -163,85 +156,81 @@ class PercentSchemaNamesTest(fixtures.TablesTest): lightweight_percent_table = self.tables.lightweight_percent_table for table in ( - percent_table, - percent_table.alias(), - lightweight_percent_table, - lightweight_percent_table.alias()): + percent_table, + percent_table.alias(), + lightweight_percent_table, + lightweight_percent_table.alias(), + ): eq_( list( config.db.execute( - table.select().order_by(table.c['percent%']) + table.select().order_by(table.c["percent%"]) ) ), - [ - (5, 12), - (7, 11), - (9, 10), - (11, 9) - ] + [(5, 12), (7, 11), (9, 10), (11, 9)], ) eq_( list( config.db.execute( - table.select(). - where(table.c['spaces % more spaces'].in_([9, 10])). - order_by(table.c['percent%']), + table.select() + .where(table.c["spaces % more spaces"].in_([9, 10])) + .order_by(table.c["percent%"]) ) ), - [ - (9, 10), - (11, 9) - ] + [(9, 10), (11, 9)], ) - row = config.db.execute(table.select(). - order_by(table.c['percent%'])).first() - eq_(row['percent%'], 5) - eq_(row['spaces % more spaces'], 12) + row = config.db.execute( + table.select().order_by(table.c["percent%"]) + ).first() + eq_(row["percent%"], 5) + eq_(row["spaces % more spaces"], 12) - eq_(row[table.c['percent%']], 5) - eq_(row[table.c['spaces % more spaces']], 12) + eq_(row[table.c["percent%"]], 5) + eq_(row[table.c["spaces % more spaces"]], 12) config.db.execute( percent_table.update().values( - {percent_table.c['spaces % more spaces']: 15} + {percent_table.c["spaces % more spaces"]: 15} ) ) eq_( list( config.db.execute( - percent_table. - select(). - order_by(percent_table.c['percent%']) + percent_table.select().order_by( + percent_table.c["percent%"] + ) ) ), - [(5, 15), (7, 15), (9, 15), (11, 15)] + [(5, 15), (7, 15), (9, 15), (11, 15)], ) -class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): +class ServerSideCursorsTest( + fixtures.TestBase, testing.AssertsExecutionResults +): - __requires__ = ('server_side_cursors', ) + __requires__ = ("server_side_cursors",) __backend__ = True def _is_server_side(self, cursor): if self.engine.dialect.driver == "psycopg2": return cursor.name - elif self.engine.dialect.driver == 'pymysql': - sscursor = __import__('pymysql.cursors').cursors.SSCursor + elif self.engine.dialect.driver == "pymysql": + sscursor = __import__("pymysql.cursors").cursors.SSCursor return isinstance(cursor, sscursor) elif self.engine.dialect.driver == "mysqldb": - sscursor = __import__('MySQLdb.cursors').cursors.SSCursor + sscursor = __import__("MySQLdb.cursors").cursors.SSCursor return isinstance(cursor, sscursor) else: return False def _fixture(self, server_side_cursors): self.engine = engines.testing_engine( - options={'server_side_cursors': server_side_cursors} + options={"server_side_cursors": server_side_cursors} ) return self.engine @@ -251,12 +240,12 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): def test_global_string(self): engine = self._fixture(True) - result = engine.execute('select 1') + result = engine.execute("select 1") assert self._is_server_side(result.cursor) def test_global_text(self): engine = self._fixture(True) - result = engine.execute(text('select 1')) + result = engine.execute(text("select 1")) assert self._is_server_side(result.cursor) def test_global_expr(self): @@ -266,7 +255,7 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): def test_global_off_explicit(self): engine = self._fixture(False) - result = engine.execute(text('select 1')) + result = engine.execute(text("select 1")) # It should be off globally ... @@ -286,10 +275,11 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): engine = self._fixture(False) # and this one - result = \ - engine.connect().execution_options(stream_results=True).\ - execute('select 1' - ) + result = ( + engine.connect() + .execution_options(stream_results=True) + .execute("select 1") + ) assert self._is_server_side(result.cursor) def test_stmt_enabled_conn_option_disabled(self): @@ -298,9 +288,9 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): s = select([1]).execution_options(stream_results=True) # not this one - result = \ - engine.connect().execution_options(stream_results=False).\ - execute(s) + result = ( + engine.connect().execution_options(stream_results=False).execute(s) + ) assert not self._is_server_side(result.cursor) def test_stmt_option_disabled(self): @@ -329,18 +319,18 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): def test_for_update_string(self): engine = self._fixture(True) - result = engine.execute('SELECT 1 FOR UPDATE') + result = engine.execute("SELECT 1 FOR UPDATE") assert self._is_server_side(result.cursor) def test_text_no_ss(self): engine = self._fixture(False) - s = text('select 42') + s = text("select 42") result = engine.execute(s) assert not self._is_server_side(result.cursor) def test_text_ss_option(self): engine = self._fixture(False) - s = text('select 42').execution_options(stream_results=True) + s = text("select 42").execution_options(stream_results=True) result = engine.execute(s) assert self._is_server_side(result.cursor) @@ -349,19 +339,25 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): md = self.metadata engine = self._fixture(True) - test_table = Table('test_table', md, - Column('id', Integer, primary_key=True), - Column('data', String(50))) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) test_table.create(checkfirst=True) - test_table.insert().execute(data='data1') - test_table.insert().execute(data='data2') - eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(), - [(1, 'data1'), (2, 'data2')]) - test_table.update().where( - test_table.c.id == 2).values( - data=test_table.c.data + - ' updated').execute() - eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(), - [(1, 'data1'), (2, 'data2 updated')]) + test_table.insert().execute(data="data1") + test_table.insert().execute(data="data2") + eq_( + test_table.select().order_by(test_table.c.id).execute().fetchall(), + [(1, "data1"), (2, "data2")], + ) + test_table.update().where(test_table.c.id == 2).values( + data=test_table.c.data + " updated" + ).execute() + eq_( + test_table.select().order_by(test_table.c.id).execute().fetchall(), + [(1, "data1"), (2, "data2 updated")], + ) test_table.delete().execute() - eq_(select([func.count('*')]).select_from(test_table).scalar(), 0) + eq_(select([func.count("*")]).select_from(test_table).scalar(), 0) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 73ce02492..032b68eb6 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -16,10 +16,12 @@ class CollateTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(100)) - ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(100)), + ) @classmethod def insert_data(cls): @@ -28,26 +30,21 @@ class CollateTest(fixtures.TablesTest): [ {"id": 1, "data": "collate data1"}, {"id": 2, "data": "collate data2"}, - ] + ], ) def _assert_result(self, select, result): - eq_( - config.db.execute(select).fetchall(), - result - ) + eq_(config.db.execute(select).fetchall(), result) @testing.requires.order_by_collation def test_collate_order_by(self): collation = testing.requires.get_order_by_collation(testing.config) self._assert_result( - select([self.tables.some_table]). - order_by(self.tables.some_table.c.data.collate(collation).asc()), - [ - (1, "collate data1"), - (2, "collate data2"), - ] + select([self.tables.some_table]).order_by( + self.tables.some_table.c.data.collate(collation).asc() + ), + [(1, "collate data1"), (2, "collate data2")], ) @@ -59,17 +56,20 @@ class OrderByLabelTest(fixtures.TablesTest): setting. """ + __backend__ = True @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer), - Column('q', String(50)), - Column('p', String(50)) - ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("q", String(50)), + Column("p", String(50)), + ) @classmethod def insert_data(cls): @@ -79,65 +79,55 @@ class OrderByLabelTest(fixtures.TablesTest): {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"}, {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"}, {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"}, - ] + ], ) def _assert_result(self, select, result): - eq_( - config.db.execute(select).fetchall(), - result - ) + eq_(config.db.execute(select).fetchall(), result) def test_plain(self): table = self.tables.some_table - lx = table.c.x.label('lx') - self._assert_result( - select([lx]).order_by(lx), - [(1, ), (2, ), (3, )] - ) + lx = table.c.x.label("lx") + self._assert_result(select([lx]).order_by(lx), [(1,), (2,), (3,)]) def test_composed_int(self): table = self.tables.some_table - lx = (table.c.x + table.c.y).label('lx') - self._assert_result( - select([lx]).order_by(lx), - [(3, ), (5, ), (7, )] - ) + lx = (table.c.x + table.c.y).label("lx") + self._assert_result(select([lx]).order_by(lx), [(3,), (5,), (7,)]) def test_composed_multiple(self): table = self.tables.some_table - lx = (table.c.x + table.c.y).label('lx') - ly = (func.lower(table.c.q) + table.c.p).label('ly') + lx = (table.c.x + table.c.y).label("lx") + ly = (func.lower(table.c.q) + table.c.p).label("ly") self._assert_result( select([lx, ly]).order_by(lx, ly.desc()), - [(3, util.u('q1p3')), (5, util.u('q2p2')), (7, util.u('q3p1'))] + [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))], ) def test_plain_desc(self): table = self.tables.some_table - lx = table.c.x.label('lx') + lx = table.c.x.label("lx") self._assert_result( - select([lx]).order_by(lx.desc()), - [(3, ), (2, ), (1, )] + select([lx]).order_by(lx.desc()), [(3,), (2,), (1,)] ) def test_composed_int_desc(self): table = self.tables.some_table - lx = (table.c.x + table.c.y).label('lx') + lx = (table.c.x + table.c.y).label("lx") self._assert_result( - select([lx]).order_by(lx.desc()), - [(7, ), (5, ), (3, )] + select([lx]).order_by(lx.desc()), [(7,), (5,), (3,)] ) @testing.requires.group_by_complex_expression def test_group_by_composed(self): table = self.tables.some_table - expr = (table.c.x + table.c.y).label('lx') - stmt = select([func.count(table.c.id), expr]).group_by(expr).order_by(expr) - self._assert_result( - stmt, - [(1, 3), (1, 5), (1, 7)] + expr = (table.c.x + table.c.y).label("lx") + stmt = ( + select([func.count(table.c.id), expr]) + .group_by(expr) + .order_by(expr) ) + self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)]) class LimitOffsetTest(fixtures.TablesTest): @@ -145,10 +135,13 @@ class LimitOffsetTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer)) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) @classmethod def insert_data(cls): @@ -159,20 +152,17 @@ class LimitOffsetTest(fixtures.TablesTest): {"id": 2, "x": 2, "y": 3}, {"id": 3, "x": 3, "y": 4}, {"id": 4, "x": 4, "y": 5}, - ] + ], ) def _assert_result(self, select, result, params=()): - eq_( - config.db.execute(select, params).fetchall(), - result - ) + eq_(config.db.execute(select, params).fetchall(), result) def test_simple_limit(self): table = self.tables.some_table self._assert_result( select([table]).order_by(table.c.id).limit(2), - [(1, 1, 2), (2, 2, 3)] + [(1, 1, 2), (2, 2, 3)], ) @testing.requires.offset @@ -180,7 +170,7 @@ class LimitOffsetTest(fixtures.TablesTest): table = self.tables.some_table self._assert_result( select([table]).order_by(table.c.id).offset(2), - [(3, 3, 4), (4, 4, 5)] + [(3, 3, 4), (4, 4, 5)], ) @testing.requires.offset @@ -188,7 +178,7 @@ class LimitOffsetTest(fixtures.TablesTest): table = self.tables.some_table self._assert_result( select([table]).order_by(table.c.id).limit(2).offset(1), - [(2, 2, 3), (3, 3, 4)] + [(2, 2, 3), (3, 3, 4)], ) @testing.requires.offset @@ -198,41 +188,40 @@ class LimitOffsetTest(fixtures.TablesTest): table = self.tables.some_table stmt = select([table]).order_by(table.c.id).limit(2).offset(1) sql = stmt.compile( - dialect=config.db.dialect, - compile_kwargs={"literal_binds": True}) + dialect=config.db.dialect, compile_kwargs={"literal_binds": True} + ) sql = str(sql) - self._assert_result( - sql, - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(sql, [(2, 2, 3), (3, 3, 4)]) @testing.requires.bound_limit_offset def test_bound_limit(self): table = self.tables.some_table self._assert_result( - select([table]).order_by(table.c.id).limit(bindparam('l')), + select([table]).order_by(table.c.id).limit(bindparam("l")), [(1, 1, 2), (2, 2, 3)], - params={"l": 2} + params={"l": 2}, ) @testing.requires.bound_limit_offset def test_bound_offset(self): table = self.tables.some_table self._assert_result( - select([table]).order_by(table.c.id).offset(bindparam('o')), + select([table]).order_by(table.c.id).offset(bindparam("o")), [(3, 3, 4), (4, 4, 5)], - params={"o": 2} + params={"o": 2}, ) @testing.requires.bound_limit_offset def test_bound_limit_offset(self): table = self.tables.some_table self._assert_result( - select([table]).order_by(table.c.id). - limit(bindparam("l")).offset(bindparam("o")), + select([table]) + .order_by(table.c.id) + .limit(bindparam("l")) + .offset(bindparam("o")), [(2, 2, 3), (3, 3, 4)], - params={"l": 2, "o": 1} + params={"l": 2, "o": 1}, ) @@ -241,10 +230,13 @@ class CompoundSelectTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer)) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) @classmethod def insert_data(cls): @@ -255,14 +247,11 @@ class CompoundSelectTest(fixtures.TablesTest): {"id": 2, "x": 2, "y": 3}, {"id": 3, "x": 3, "y": 4}, {"id": 4, "x": 4, "y": 5}, - ] + ], ) def _assert_result(self, select, result, params=()): - eq_( - config.db.execute(select, params).fetchall(), - result - ) + eq_(config.db.execute(select, params).fetchall(), result) def test_plain_union(self): table = self.tables.some_table @@ -270,10 +259,7 @@ class CompoundSelectTest(fixtures.TablesTest): s2 = select([table]).where(table.c.id == 3) u1 = union(s1, s2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) def test_select_from_plain_union(self): table = self.tables.some_table @@ -281,80 +267,88 @@ class CompoundSelectTest(fixtures.TablesTest): s2 = select([table]).where(table.c.id == 3) u1 = union(s1, s2).alias().select() - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) @testing.requires.order_by_col_from_union @testing.requires.parens_in_union_contained_select_w_limit_offset def test_limit_offset_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - limit(1).order_by(table.c.id) - s2 = select([table]).where(table.c.id == 3).\ - limit(1).order_by(table.c.id) + s1 = ( + select([table]) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + ) + s2 = ( + select([table]) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + ) u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) @testing.requires.parens_in_union_contained_select_wo_limit_offset def test_order_by_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - order_by(table.c.id) - s2 = select([table]).where(table.c.id == 3).\ - order_by(table.c.id) + s1 = select([table]).where(table.c.id == 2).order_by(table.c.id) + s2 = select([table]).where(table.c.id == 3).order_by(table.c.id) u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) def test_distinct_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - distinct() - s2 = select([table]).where(table.c.id == 3).\ - distinct() + s1 = select([table]).where(table.c.id == 2).distinct() + s2 = select([table]).where(table.c.id == 3).distinct() u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) @testing.requires.parens_in_union_contained_select_w_limit_offset def test_limit_offset_in_unions_from_alias(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - limit(1).order_by(table.c.id) - s2 = select([table]).where(table.c.id == 3).\ - limit(1).order_by(table.c.id) + s1 = ( + select([table]) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + ) + s2 = ( + select([table]) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + ) # this necessarily has double parens u1 = union(s1, s2).alias() self._assert_result( - u1.select().limit(2).order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] + u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] ) def test_limit_offset_aliased_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - limit(1).order_by(table.c.id).alias().select() - s2 = select([table]).where(table.c.id == 3).\ - limit(1).order_by(table.c.id).alias().select() + s1 = ( + select([table]) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select([table]) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) class ExpandingBoundInTest(fixtures.TablesTest): @@ -362,11 +356,14 @@ class ExpandingBoundInTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer), - Column('z', String(50))) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) @classmethod def insert_data(cls): @@ -377,178 +374,184 @@ class ExpandingBoundInTest(fixtures.TablesTest): {"id": 2, "x": 2, "y": 3, "z": "z2"}, {"id": 3, "x": 3, "y": 4, "z": "z3"}, {"id": 4, "x": 4, "y": 5, "z": "z4"}, - ] + ], ) def _assert_result(self, select, result, params=()): - eq_( - config.db.execute(select, params).fetchall(), - result - ) + eq_(config.db.execute(select, params).fetchall(), result) def test_multiple_empty_sets(self): # test that any anonymous aliasing used by the dialect # is fine with duplicates table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.in_(bindparam('q', expanding=True))).where( - table.c.y.in_(bindparam('p', expanding=True)) - ).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": [], "p": []}, + stmt = ( + select([table.c.id]) + .where(table.c.x.in_(bindparam("q", expanding=True))) + .where(table.c.y.in_(bindparam("p", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": [], "p": []}) + @testing.requires.tuple_in def test_empty_heterogeneous_tuples(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.z).in_( - bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.z).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + @testing.requires.tuple_in def test_empty_homogeneous_tuples(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.y).in_( - bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.y).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + def test_bound_in_scalar(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [(2, ), (3, ), (4, )], - params={"q": [2, 3, 4]}, + stmt = ( + select([table.c.id]) + .where(table.c.x.in_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]}) + @testing.requires.tuple_in def test_bound_in_two_tuple(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.y).in_( - bindparam('q', expanding=True))).order_by(table.c.id) + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.y).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) + ) self._assert_result( - stmt, - [(2, ), (3, ), (4, )], - params={"q": [(2, 3), (3, 4), (4, 5)]}, + stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]} ) @testing.requires.tuple_in def test_bound_in_heterogeneous_two_tuple(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.z).in_( - bindparam('q', expanding=True))).order_by(table.c.id) + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.z).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) + ) self._assert_result( stmt, - [(2, ), (3, ), (4, )], + [(2,), (3,), (4,)], params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, ) def test_empty_set_against_integer(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.x.in_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + def test_empty_set_against_integer_negation(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.notin_(bindparam('q', expanding=True)) - ).order_by(table.c.id) - - self._assert_result( - stmt, - [(1, ), (2, ), (3, ), (4, )], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.x.notin_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + def test_empty_set_against_string(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.z.in_(bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.z.in_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + def test_empty_set_against_string_negation(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.z.notin_(bindparam('q', expanding=True)) - ).order_by(table.c.id) - - self._assert_result( - stmt, - [(1, ), (2, ), (3, ), (4, )], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.z.notin_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + def test_null_in_empty_set_is_false(self): - stmt = select([ - case( - [ - ( - null().in_(bindparam('foo', value=(), expanding=True)), - true() - ) - ], - else_=false() - ) - ]) - in_( - config.db.execute(stmt).fetchone()[0], - (False, 0) + stmt = select( + [ + case( + [ + ( + null().in_( + bindparam("foo", value=(), expanding=True) + ), + true(), + ) + ], + else_=false(), + ) + ] ) + in_(config.db.execute(stmt).fetchone()[0], (False, 0)) class LikeFunctionsTest(fixtures.TablesTest): __backend__ = True - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) @classmethod def insert_data(cls): @@ -565,7 +568,7 @@ class LikeFunctionsTest(fixtures.TablesTest): {"id": 8, "data": "ab9cdefg"}, {"id": 9, "data": "abcde#fg"}, {"id": 10, "data": "abcd9fg"}, - ] + ], ) def _test(self, expr, expected): @@ -573,8 +576,10 @@ class LikeFunctionsTest(fixtures.TablesTest): with config.db.connect() as conn: rows = { - value for value, in - conn.execute(select([some_table.c.id]).where(expr)) + value + for value, in conn.execute( + select([some_table.c.id]).where(expr) + ) } eq_(rows, expected) @@ -591,7 +596,8 @@ class LikeFunctionsTest(fixtures.TablesTest): col = self.tables.some_table.c.data self._test( col.startswith(literal_column("'ab%c'")), - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + ) def test_startswith_escape(self): col = self.tables.some_table.c.data @@ -608,8 +614,9 @@ class LikeFunctionsTest(fixtures.TablesTest): def test_endswith_sqlexpr(self): col = self.tables.some_table.c.data - self._test(col.endswith(literal_column("'e%fg'")), - {1, 2, 3, 4, 5, 6, 7, 8, 9}) + self._test( + col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9} + ) def test_endswith_autoescape(self): col = self.tables.some_table.c.data @@ -640,5 +647,3 @@ class LikeFunctionsTest(fixtures.TablesTest): col = self.tables.some_table.c.data self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) self._test(col.contains("b#cd", autoescape=True, escape="#"), {7}) - - diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index f1c00de6b..15a850fe9 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -9,140 +9,144 @@ from ..schema import Table, Column class SequenceTest(fixtures.TablesTest): - __requires__ = ('sequences',) + __requires__ = ("sequences",) __backend__ = True - run_create_tables = 'each' + run_create_tables = "each" @classmethod def define_tables(cls, metadata): - Table('seq_pk', metadata, - Column('id', Integer, Sequence('tab_id_seq'), primary_key=True), - Column('data', String(50)) - ) + Table( + "seq_pk", + metadata, + Column("id", Integer, Sequence("tab_id_seq"), primary_key=True), + Column("data", String(50)), + ) - Table('seq_opt_pk', metadata, - Column('id', Integer, Sequence('tab_id_seq', optional=True), - primary_key=True), - Column('data', String(50)) - ) + Table( + "seq_opt_pk", + metadata, + Column( + "id", + Integer, + Sequence("tab_id_seq", optional=True), + primary_key=True, + ), + Column("data", String(50)), + ) def test_insert_roundtrip(self): - config.db.execute( - self.tables.seq_pk.insert(), - data="some data" - ) + config.db.execute(self.tables.seq_pk.insert(), data="some data") self._assert_round_trip(self.tables.seq_pk, config.db) def test_insert_lastrowid(self): - r = config.db.execute( - self.tables.seq_pk.insert(), - data="some data" - ) - eq_( - r.inserted_primary_key, - [1] - ) + r = config.db.execute(self.tables.seq_pk.insert(), data="some data") + eq_(r.inserted_primary_key, [1]) def test_nextval_direct(self): - r = config.db.execute( - self.tables.seq_pk.c.id.default - ) - eq_( - r, 1 - ) + r = config.db.execute(self.tables.seq_pk.c.id.default) + eq_(r, 1) @requirements.sequences_optional def test_optional_seq(self): r = config.db.execute( - self.tables.seq_opt_pk.insert(), - data="some data" - ) - eq_( - r.inserted_primary_key, - [1] + self.tables.seq_opt_pk.insert(), data="some data" ) + eq_(r.inserted_primary_key, [1]) def _assert_round_trip(self, table, conn): row = conn.execute(table.select()).first() - eq_( - row, - (1, "some data") - ) + eq_(row, (1, "some data")) class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase): - __requires__ = ('sequences',) + __requires__ = ("sequences",) __backend__ = True def test_literal_binds_inline_compile(self): table = Table( - 'x', MetaData(), - Column('y', Integer, Sequence('y_seq')), - Column('q', Integer)) + "x", + MetaData(), + Column("y", Integer, Sequence("y_seq")), + Column("q", Integer), + ) stmt = table.insert().values(q=5) seq_nextval = testing.db.dialect.statement_compiler( - statement=None, dialect=testing.db.dialect).visit_sequence( - Sequence("y_seq")) + statement=None, dialect=testing.db.dialect + ).visit_sequence(Sequence("y_seq")) self.assert_compile( stmt, - "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval, ), + "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,), literal_binds=True, - dialect=testing.db.dialect) + dialect=testing.db.dialect, + ) class HasSequenceTest(fixtures.TestBase): - __requires__ = 'sequences', + __requires__ = ("sequences",) __backend__ = True def test_has_sequence(self): - s1 = Sequence('user_id_seq') + s1 = Sequence("user_id_seq") testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence(testing.db, - 'user_id_seq'), True) + eq_( + testing.db.dialect.has_sequence(testing.db, "user_id_seq"), + True, + ) finally: testing.db.execute(schema.DropSequence(s1)) @testing.requires.schemas def test_has_sequence_schema(self): - s1 = Sequence('user_id_seq', schema=config.test_schema) + s1 = Sequence("user_id_seq", schema=config.test_schema) testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence( - testing.db, 'user_id_seq', schema=config.test_schema), True) + eq_( + testing.db.dialect.has_sequence( + testing.db, "user_id_seq", schema=config.test_schema + ), + True, + ) finally: testing.db.execute(schema.DropSequence(s1)) def test_has_sequence_neg(self): - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), - False) + eq_(testing.db.dialect.has_sequence(testing.db, "user_id_seq"), False) @testing.requires.schemas def test_has_sequence_schemas_neg(self): - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', - schema=config.test_schema), - False) + eq_( + testing.db.dialect.has_sequence( + testing.db, "user_id_seq", schema=config.test_schema + ), + False, + ) @testing.requires.schemas def test_has_sequence_default_not_in_remote(self): - s1 = Sequence('user_id_seq') + s1 = Sequence("user_id_seq") testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', - schema=config.test_schema), - False) + eq_( + testing.db.dialect.has_sequence( + testing.db, "user_id_seq", schema=config.test_schema + ), + False, + ) finally: testing.db.execute(schema.DropSequence(s1)) @testing.requires.schemas def test_has_sequence_remote_not_in_default(self): - s1 = Sequence('user_id_seq', schema=config.test_schema) + s1 = Sequence("user_id_seq", schema=config.test_schema) testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), - False) + eq_( + testing.db.dialect.has_sequence(testing.db, "user_id_seq"), + False, + ) finally: testing.db.execute(schema.DropSequence(s1)) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 27c7bb115..6dfb80915 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -4,9 +4,24 @@ from .. import fixtures, config from ..assertions import eq_ from ..config import requirements from sqlalchemy import Integer, Unicode, UnicodeText, select, TIMESTAMP -from sqlalchemy import Date, DateTime, Time, MetaData, String, \ - Text, Numeric, Float, literal, Boolean, cast, null, JSON, and_, \ - type_coerce, BigInteger +from sqlalchemy import ( + Date, + DateTime, + Time, + MetaData, + String, + Text, + Numeric, + Float, + literal, + Boolean, + cast, + null, + JSON, + and_, + type_coerce, + BigInteger, +) from ..schema import Table, Column from ... import testing import decimal @@ -24,13 +39,17 @@ class _LiteralRoundTripFixture(object): # into a typed column. we can then SELECT it back as its # official type; ideally we'd be able to use CAST here # but MySQL in particular can't CAST fully - t = Table('t', self.metadata, Column('x', type_)) + t = Table("t", self.metadata, Column("x", type_)) t.create() for value in input_: - ins = t.insert().values(x=literal(value)).compile( - dialect=testing.db.dialect, - compile_kwargs=dict(literal_binds=True) + ins = ( + t.insert() + .values(x=literal(value)) + .compile( + dialect=testing.db.dialect, + compile_kwargs=dict(literal_binds=True), + ) ) testing.db.execute(ins) @@ -42,40 +61,33 @@ class _LiteralRoundTripFixture(object): class _UnicodeFixture(_LiteralRoundTripFixture): - __requires__ = 'unicode_data', + __requires__ = ("unicode_data",) - data = u("Alors vous imaginez ma surprise, au lever du jour, " - "quand une drôle de petite voix m’a réveillé. Elle " - "disait: « S’il vous plaît… dessine-moi un mouton! »") + data = u( + "Alors vous imaginez ma surprise, au lever du jour, " + "quand une drôle de petite voix m’a réveillé. Elle " + "disait: « S’il vous plaît… dessine-moi un mouton! »" + ) @classmethod def define_tables(cls, metadata): - Table('unicode_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('unicode_data', cls.datatype), - ) + Table( + "unicode_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("unicode_data", cls.datatype), + ) def test_round_trip(self): unicode_table = self.tables.unicode_table - config.db.execute( - unicode_table.insert(), - { - 'unicode_data': self.data, - } - ) + config.db.execute(unicode_table.insert(), {"unicode_data": self.data}) - row = config.db.execute( - select([ - unicode_table.c.unicode_data, - ]) - ).first() + row = config.db.execute(select([unicode_table.c.unicode_data])).first() - eq_( - row, - (self.data, ) - ) + eq_(row, (self.data,)) assert isinstance(row[0], util.text_type) def test_round_trip_executemany(self): @@ -83,44 +95,29 @@ class _UnicodeFixture(_LiteralRoundTripFixture): config.db.execute( unicode_table.insert(), - [ - { - 'unicode_data': self.data, - } - for i in range(3) - ] + [{"unicode_data": self.data} for i in range(3)], ) rows = config.db.execute( - select([ - unicode_table.c.unicode_data, - ]) + select([unicode_table.c.unicode_data]) ).fetchall() - eq_( - rows, - [(self.data, ) for i in range(3)] - ) + eq_(rows, [(self.data,) for i in range(3)]) for row in rows: assert isinstance(row[0], util.text_type) def _test_empty_strings(self): unicode_table = self.tables.unicode_table - config.db.execute( - unicode_table.insert(), - {"unicode_data": u('')} - ) - row = config.db.execute( - select([unicode_table.c.unicode_data]) - ).first() - eq_(row, (u(''),)) + config.db.execute(unicode_table.insert(), {"unicode_data": u("")}) + row = config.db.execute(select([unicode_table.c.unicode_data])).first() + eq_(row, (u(""),)) def test_literal(self): self._literal_round_trip(self.datatype, [self.data], [self.data]) class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): - __requires__ = 'unicode_data', + __requires__ = ("unicode_data",) __backend__ = True datatype = Unicode(255) @@ -131,7 +128,7 @@ class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): - __requires__ = 'unicode_data', 'text_type' + __requires__ = "unicode_data", "text_type" __backend__ = True datatype = UnicodeText() @@ -142,54 +139,47 @@ class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): - __requires__ = 'text_type', + __requires__ = ("text_type",) __backend__ = True @classmethod def define_tables(cls, metadata): - Table('text_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('text_data', Text), - ) + Table( + "text_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("text_data", Text), + ) def test_text_roundtrip(self): text_table = self.tables.text_table - config.db.execute( - text_table.insert(), - {"text_data": 'some text'} - ) - row = config.db.execute( - select([text_table.c.text_data]) - ).first() - eq_(row, ('some text',)) + config.db.execute(text_table.insert(), {"text_data": "some text"}) + row = config.db.execute(select([text_table.c.text_data])).first() + eq_(row, ("some text",)) def test_text_empty_strings(self): text_table = self.tables.text_table - config.db.execute( - text_table.insert(), - {"text_data": ''} - ) - row = config.db.execute( - select([text_table.c.text_data]) - ).first() - eq_(row, ('',)) + config.db.execute(text_table.insert(), {"text_data": ""}) + row = config.db.execute(select([text_table.c.text_data])).first() + eq_(row, ("",)) def test_literal(self): self._literal_round_trip(Text, ["some text"], ["some text"]) def test_literal_quoting(self): - data = '''some 'text' hey "hi there" that's text''' + data = """some 'text' hey "hi there" that's text""" self._literal_round_trip(Text, [data], [data]) def test_literal_backslashes(self): - data = r'backslash one \ backslash two \\ end' + data = r"backslash one \ backslash two \\ end" self._literal_round_trip(Text, [data], [data]) def test_literal_percentsigns(self): - data = r'percent % signs %% percent' + data = r"percent % signs %% percent" self._literal_round_trip(Text, [data], [data]) @@ -199,9 +189,7 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): @requirements.unbounded_varchar def test_nolength_string(self): metadata = MetaData() - foo = Table('foo', metadata, - Column('one', String) - ) + foo = Table("foo", metadata, Column("one", String)) foo.create(config.db) foo.drop(config.db) @@ -210,11 +198,11 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): self._literal_round_trip(String(40), ["some text"], ["some text"]) def test_literal_quoting(self): - data = '''some 'text' hey "hi there" that's text''' + data = """some 'text' hey "hi there" that's text""" self._literal_round_trip(String(40), [data], [data]) def test_literal_backslashes(self): - data = r'backslash one \ backslash two \\ end' + data = r"backslash one \ backslash two \\ end" self._literal_round_trip(String(40), [data], [data]) @@ -223,44 +211,32 @@ class _DateFixture(_LiteralRoundTripFixture): @classmethod def define_tables(cls, metadata): - Table('date_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('date_data', cls.datatype), - ) + Table( + "date_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("date_data", cls.datatype), + ) def test_round_trip(self): date_table = self.tables.date_table - config.db.execute( - date_table.insert(), - {'date_data': self.data} - ) + config.db.execute(date_table.insert(), {"date_data": self.data}) - row = config.db.execute( - select([ - date_table.c.date_data, - ]) - ).first() + row = config.db.execute(select([date_table.c.date_data])).first() compare = self.compare or self.data - eq_(row, - (compare, )) + eq_(row, (compare,)) assert isinstance(row[0], type(compare)) def test_null(self): date_table = self.tables.date_table - config.db.execute( - date_table.insert(), - {'date_data': None} - ) + config.db.execute(date_table.insert(), {"date_data": None}) - row = config.db.execute( - select([ - date_table.c.date_data, - ]) - ).first() + row = config.db.execute(select([date_table.c.date_data])).first() eq_(row, (None,)) @testing.requires.datetime_literals @@ -270,48 +246,49 @@ class _DateFixture(_LiteralRoundTripFixture): class DateTimeTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'datetime', + __requires__ = ("datetime",) __backend__ = True datatype = DateTime data = datetime.datetime(2012, 10, 15, 12, 57, 18) class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'datetime_microseconds', + __requires__ = ("datetime_microseconds",) __backend__ = True datatype = DateTime data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'timestamp_microseconds', + __requires__ = ("timestamp_microseconds",) __backend__ = True datatype = TIMESTAMP data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) class TimeTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'time', + __requires__ = ("time",) __backend__ = True datatype = Time data = datetime.time(12, 57, 18) class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'time_microseconds', + __requires__ = ("time_microseconds",) __backend__ = True datatype = Time data = datetime.time(12, 57, 18, 396) class DateTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'date', + __requires__ = ("date",) __backend__ = True datatype = Date data = datetime.date(2012, 10, 15) class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'date', 'date_coerces_from_datetime' + __requires__ = "date", "date_coerces_from_datetime" __backend__ = True datatype = Date data = datetime.datetime(2012, 10, 15, 12, 57, 18) @@ -319,14 +296,14 @@ class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'datetime_historic', + __requires__ = ("datetime_historic",) __backend__ = True datatype = DateTime data = datetime.datetime(1850, 11, 10, 11, 52, 35) class DateHistoricTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'date_historic', + __requires__ = ("date_historic",) __backend__ = True datatype = Date data = datetime.date(1727, 4, 1) @@ -345,26 +322,21 @@ class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): def _round_trip(self, datatype, data): metadata = self.metadata int_table = Table( - 'integer_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('integer_data', datatype), + "integer_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("integer_data", datatype), ) metadata.create_all(config.db) - config.db.execute( - int_table.insert(), - {'integer_data': data} - ) + config.db.execute(int_table.insert(), {"integer_data": data}) - row = config.db.execute( - select([ - int_table.c.integer_data, - ]) - ).first() + row = config.db.execute(select([int_table.c.integer_data])).first() - eq_(row, (data, )) + eq_(row, (data,)) if util.py3k: assert isinstance(row[0], int) @@ -377,12 +349,11 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): @testing.emits_warning(r".*does \*not\* support Decimal objects natively") @testing.provide_metadata - def _do_test(self, type_, input_, output, - filter_=None, check_scale=False): + def _do_test(self, type_, input_, output, filter_=None, check_scale=False): metadata = self.metadata - t = Table('t', metadata, Column('x', type_)) + t = Table("t", metadata, Column("x", type_)) t.create() - t.insert().execute([{'x': x} for x in input_]) + t.insert().execute([{"x": x} for x in input_]) result = {row[0] for row in t.select().execute()} output = set(output) @@ -391,10 +362,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): output = set(filter_(x) for x in output) eq_(result, output) if check_scale: - eq_( - [str(x) for x in result], - [str(x) for x in output], - ) + eq_([str(x) for x in result], [str(x) for x in output]) @testing.emits_warning(r".*does \*not\* support Decimal objects natively") def test_render_literal_numeric(self): @@ -416,8 +384,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): self._literal_round_trip( Float(4), [15.7563, decimal.Decimal("15.7563")], - [15.7563, ], - filter_=lambda n: n is not None and round(n, 5) or None + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None, ) @testing.requires.precision_generic_float_type @@ -425,8 +393,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): self._do_test( Float(None, decimal_return_scale=7, asdecimal=True), [15.7563827, decimal.Decimal("15.7563827")], - [decimal.Decimal("15.7563827"), ], - check_scale=True + [decimal.Decimal("15.7563827")], + check_scale=True, ) def test_numeric_as_decimal(self): @@ -445,18 +413,12 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): @testing.requires.fetch_null_from_numeric def test_numeric_null_as_decimal(self): - self._do_test( - Numeric(precision=8, scale=4), - [None], - [None], - ) + self._do_test(Numeric(precision=8, scale=4), [None], [None]) @testing.requires.fetch_null_from_numeric def test_numeric_null_as_float(self): self._do_test( - Numeric(precision=8, scale=4, asdecimal=False), - [None], - [None], + Numeric(precision=8, scale=4, asdecimal=False), [None], [None] ) @testing.requires.floats_to_four_decimals @@ -472,15 +434,13 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): Float(precision=8), [15.7563, decimal.Decimal("15.7563")], [15.7563], - filter_=lambda n: n is not None and round(n, 5) or None + filter_=lambda n: n is not None and round(n, 5) or None, ) def test_float_coerce_round_trip(self): expr = 15.7563 - val = testing.db.scalar( - select([literal(expr)]) - ) + val = testing.db.scalar(select([literal(expr)])) eq_(val, expr) # this does not work in MySQL, see #4036, however we choose not @@ -491,34 +451,28 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): def test_decimal_coerce_round_trip(self): expr = decimal.Decimal("15.7563") - val = testing.db.scalar( - select([literal(expr)]) - ) + val = testing.db.scalar(select([literal(expr)])) eq_(val, expr) @testing.emits_warning(r".*does \*not\* support Decimal objects natively") def test_decimal_coerce_round_trip_w_cast(self): expr = decimal.Decimal("15.7563") - val = testing.db.scalar( - select([cast(expr, Numeric(10, 4))]) - ) + val = testing.db.scalar(select([cast(expr, Numeric(10, 4))])) eq_(val, expr) @testing.requires.precision_numerics_general def test_precision_decimal(self): - numbers = set([ - decimal.Decimal("54.234246451650"), - decimal.Decimal("0.004354"), - decimal.Decimal("900.0"), - ]) - - self._do_test( - Numeric(precision=18, scale=12), - numbers, - numbers, + numbers = set( + [ + decimal.Decimal("54.234246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + ] ) + self._do_test(Numeric(precision=18, scale=12), numbers, numbers) + @testing.requires.precision_numerics_enotation_large def test_enotation_decimal(self): """test exceedingly small decimals. @@ -528,25 +482,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): """ - numbers = set([ - decimal.Decimal('1E-2'), - decimal.Decimal('1E-3'), - decimal.Decimal('1E-4'), - decimal.Decimal('1E-5'), - decimal.Decimal('1E-6'), - decimal.Decimal('1E-7'), - decimal.Decimal('1E-8'), - decimal.Decimal("0.01000005940696"), - decimal.Decimal("0.00000005940696"), - decimal.Decimal("0.00000000000696"), - decimal.Decimal("0.70000000000696"), - decimal.Decimal("696E-12"), - ]) - self._do_test( - Numeric(precision=18, scale=14), - numbers, - numbers + numbers = set( + [ + decimal.Decimal("1E-2"), + decimal.Decimal("1E-3"), + decimal.Decimal("1E-4"), + decimal.Decimal("1E-5"), + decimal.Decimal("1E-6"), + decimal.Decimal("1E-7"), + decimal.Decimal("1E-8"), + decimal.Decimal("0.01000005940696"), + decimal.Decimal("0.00000005940696"), + decimal.Decimal("0.00000000000696"), + decimal.Decimal("0.70000000000696"), + decimal.Decimal("696E-12"), + ] ) + self._do_test(Numeric(precision=18, scale=14), numbers, numbers) @testing.requires.precision_numerics_enotation_large def test_enotation_decimal_large(self): @@ -554,41 +506,32 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): """ - numbers = set([ - decimal.Decimal('4E+8'), - decimal.Decimal("5748E+15"), - decimal.Decimal('1.521E+15'), - decimal.Decimal('00000000000000.1E+12'), - ]) - self._do_test( - Numeric(precision=25, scale=2), - numbers, - numbers + numbers = set( + [ + decimal.Decimal("4E+8"), + decimal.Decimal("5748E+15"), + decimal.Decimal("1.521E+15"), + decimal.Decimal("00000000000000.1E+12"), + ] ) + self._do_test(Numeric(precision=25, scale=2), numbers, numbers) @testing.requires.precision_numerics_many_significant_digits def test_many_significant_digits(self): - numbers = set([ - decimal.Decimal("31943874831932418390.01"), - decimal.Decimal("319438950232418390.273596"), - decimal.Decimal("87673.594069654243"), - ]) - self._do_test( - Numeric(precision=38, scale=12), - numbers, - numbers + numbers = set( + [ + decimal.Decimal("31943874831932418390.01"), + decimal.Decimal("319438950232418390.273596"), + decimal.Decimal("87673.594069654243"), + ] ) + self._do_test(Numeric(precision=38, scale=12), numbers, numbers) @testing.requires.precision_numerics_retains_significant_digits def test_numeric_no_decimal(self): - numbers = set([ - decimal.Decimal("1.000") - ]) + numbers = set([decimal.Decimal("1.000")]) self._do_test( - Numeric(precision=5, scale=3), - numbers, - numbers, - check_scale=True + Numeric(precision=5, scale=3), numbers, numbers, check_scale=True ) @@ -597,42 +540,32 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table('boolean_table', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('value', Boolean), - Column('unconstrained_value', Boolean(create_constraint=False)), - ) + Table( + "boolean_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("value", Boolean), + Column("unconstrained_value", Boolean(create_constraint=False)), + ) def test_render_literal_bool(self): - self._literal_round_trip( - Boolean(), - [True, False], - [True, False] - ) + self._literal_round_trip(Boolean(), [True, False], [True, False]) def test_round_trip(self): boolean_table = self.tables.boolean_table config.db.execute( boolean_table.insert(), - { - 'id': 1, - 'value': True, - 'unconstrained_value': False - } + {"id": 1, "value": True, "unconstrained_value": False}, ) row = config.db.execute( - select([ - boolean_table.c.value, - boolean_table.c.unconstrained_value - ]) + select( + [boolean_table.c.value, boolean_table.c.unconstrained_value] + ) ).first() - eq_( - row, - (True, False) - ) + eq_(row, (True, False)) assert isinstance(row[0], bool) def test_null(self): @@ -640,24 +573,16 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): config.db.execute( boolean_table.insert(), - { - 'id': 1, - 'value': None, - 'unconstrained_value': None - } + {"id": 1, "value": None, "unconstrained_value": None}, ) row = config.db.execute( - select([ - boolean_table.c.value, - boolean_table.c.unconstrained_value - ]) + select( + [boolean_table.c.value, boolean_table.c.unconstrained_value] + ) ).first() - eq_( - row, - (None, None) - ) + eq_(row, (None, None)) def test_whereclause(self): # testing "WHERE <column>" renders a compatible expression @@ -667,92 +592,82 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): conn.execute( boolean_table.insert(), [ - {'id': 1, 'value': True, 'unconstrained_value': True}, - {'id': 2, 'value': False, 'unconstrained_value': False} - ] + {"id": 1, "value": True, "unconstrained_value": True}, + {"id": 2, "value": False, "unconstrained_value": False}, + ], ) eq_( conn.scalar( select([boolean_table.c.id]).where(boolean_table.c.value) ), - 1 + 1, ) eq_( conn.scalar( select([boolean_table.c.id]).where( - boolean_table.c.unconstrained_value) + boolean_table.c.unconstrained_value + ) ), - 1 + 1, ) eq_( conn.scalar( select([boolean_table.c.id]).where(~boolean_table.c.value) ), - 2 + 2, ) eq_( conn.scalar( select([boolean_table.c.id]).where( - ~boolean_table.c.unconstrained_value) + ~boolean_table.c.unconstrained_value + ) ), - 2 + 2, ) - - class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): - __requires__ = 'json_type', + __requires__ = ("json_type",) __backend__ = True datatype = JSON - data1 = { - "key1": "value1", - "key2": "value2" - } + data1 = {"key1": "value1", "key2": "value2"} data2 = { "Key 'One'": "value1", "key two": "value2", - "key three": "value ' three '" + "key three": "value ' three '", } data3 = { "key1": [1, 2, 3], "key2": ["one", "two", "three"], - "key3": [{"four": "five"}, {"six": "seven"}] + "key3": [{"four": "five"}, {"six": "seven"}], } data4 = ["one", "two", "three"] data5 = { "nested": { - "elem1": [ - {"a": "b", "c": "d"}, - {"e": "f", "g": "h"} - ], - "elem2": { - "elem3": {"elem4": "elem5"} - } + "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}], + "elem2": {"elem3": {"elem4": "elem5"}}, } } - data6 = { - "a": 5, - "b": "some value", - "c": {"foo": "bar"} - } + data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}} @classmethod def define_tables(cls, metadata): - Table('data_table', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30), nullable=False), - Column('data', cls.datatype), - Column('nulldata', cls.datatype(none_as_null=True)) - ) + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.datatype), + Column("nulldata", cls.datatype(none_as_null=True)), + ) def test_round_trip_data1(self): self._test_round_trip(self.data1) @@ -761,99 +676,82 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): data_table = self.tables.data_table config.db.execute( - data_table.insert(), - {'name': 'row1', 'data': data_element} + data_table.insert(), {"name": "row1", "data": data_element} ) - row = config.db.execute( - select([ - data_table.c.data, - ]) - ).first() + row = config.db.execute(select([data_table.c.data])).first() - eq_(row, (data_element, )) + eq_(row, (data_element,)) def test_round_trip_none_as_sql_null(self): - col = self.tables.data_table.c['nulldata'] + col = self.tables.data_table.c["nulldata"] with config.db.connect() as conn: conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": None} + self.tables.data_table.insert(), {"name": "r1", "data": None} ) eq_( conn.scalar( - select([self.tables.data_table.c.name]). - where(col.is_(null())) + select([self.tables.data_table.c.name]).where( + col.is_(null()) + ) ), - "r1" + "r1", ) - eq_( - conn.scalar( - select([col]) - ), - None - ) + eq_(conn.scalar(select([col])), None) def test_round_trip_json_null_as_json_null(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] with config.db.connect() as conn: conn.execute( self.tables.data_table.insert(), - {"name": "r1", "data": JSON.NULL} + {"name": "r1", "data": JSON.NULL}, ) eq_( conn.scalar( - select([self.tables.data_table.c.name]). - where(cast(col, String) == 'null') + select([self.tables.data_table.c.name]).where( + cast(col, String) == "null" + ) ), - "r1" + "r1", ) - eq_( - conn.scalar( - select([col]) - ), - None - ) + eq_(conn.scalar(select([col])), None) def test_round_trip_none_as_json_null(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] with config.db.connect() as conn: conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": None} + self.tables.data_table.insert(), {"name": "r1", "data": None} ) eq_( conn.scalar( - select([self.tables.data_table.c.name]). - where(cast(col, String) == 'null') + select([self.tables.data_table.c.name]).where( + cast(col, String) == "null" + ) ), - "r1" + "r1", ) - eq_( - conn.scalar( - select([col]) - ), - None - ) + eq_(conn.scalar(select([col])), None) def _criteria_fixture(self): config.db.execute( self.tables.data_table.insert(), - [{"name": "r1", "data": self.data1}, - {"name": "r2", "data": self.data2}, - {"name": "r3", "data": self.data3}, - {"name": "r4", "data": self.data4}, - {"name": "r5", "data": self.data5}, - {"name": "r6", "data": self.data6}] + [ + {"name": "r1", "data": self.data1}, + {"name": "r2", "data": self.data2}, + {"name": "r3", "data": self.data3}, + {"name": "r4", "data": self.data4}, + {"name": "r5", "data": self.data5}, + {"name": "r6", "data": self.data6}, + ], ) def _test_index_criteria(self, crit, expected, test_literal=True): @@ -861,20 +759,20 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): with config.db.connect() as conn: stmt = select([self.tables.data_table.c.name]).where(crit) - eq_( - conn.scalar(stmt), - expected - ) + eq_(conn.scalar(stmt), expected) if test_literal: - literal_sql = str(stmt.compile( - config.db, compile_kwargs={"literal_binds": True})) + literal_sql = str( + stmt.compile( + config.db, compile_kwargs={"literal_binds": True} + ) + ) eq_(conn.scalar(literal_sql), expected) def test_crit_spaces_in_key(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] # limit the rows here to avoid PG error # "cannot extract field from a non-object", which is @@ -882,76 +780,74 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): self._test_index_criteria( and_( name.in_(["r1", "r2", "r3"]), - cast(col["key two"], String) == '"value2"' + cast(col["key two"], String) == '"value2"', ), - "r2" + "r2", ) @config.requirements.json_array_indexes def test_crit_simple_int(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] # limit the rows here to avoid PG error # "cannot extract array element from a non-array", which is # fixed in 9.4 but may exist in 9.3 self._test_index_criteria( - and_(name == 'r4', cast(col[1], String) == '"two"'), - "r4" + and_(name == "r4", cast(col[1], String) == '"two"'), "r4" ) def test_crit_mixed_path(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - cast(col[("key3", 1, "six")], String) == '"seven"', - "r3" + cast(col[("key3", 1, "six")], String) == '"seven"', "r3" ) def test_crit_string_path(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( cast(col[("nested", "elem2", "elem3", "elem4")], String) == '"elem5"', - "r5" + "r5", ) def test_crit_against_string_basic(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', cast(col["b"], String) == '"some value"'), - "r6" + and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6" ) def test_crit_against_string_coerce_type(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', - cast(col["b"], String) == type_coerce("some value", JSON)), + and_( + name == "r6", + cast(col["b"], String) == type_coerce("some value", JSON), + ), "r6", - test_literal=False + test_literal=False, ) def test_crit_against_int_basic(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', cast(col["a"], String) == '5'), - "r6" + and_(name == "r6", cast(col["a"], String) == "5"), "r6" ) def test_crit_against_int_coerce_type(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', cast(col["a"], String) == type_coerce(5, JSON)), + and_(name == "r6", cast(col["a"], String) == type_coerce(5, JSON)), "r6", - test_literal=False + test_literal=False, ) def test_unicode_round_trip(self): @@ -961,17 +857,17 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): { "name": "r1", "data": { - util.u('réveillé'): util.u('réveillé'), - "data": {"k1": util.u('drôle')} - } - } + util.u("réveillé"): util.u("réveillé"), + "data": {"k1": util.u("drôle")}, + }, + }, ) eq_( conn.scalar(select([self.tables.data_table.c.data])), { - util.u('réveillé'): util.u('réveillé'), - "data": {"k1": util.u('drôle')} + util.u("réveillé"): util.u("réveillé"), + "data": {"k1": util.u("drôle")}, }, ) @@ -986,7 +882,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): s = Session(testing.db) - d1 = Data(name='d1', data=None, nulldata=None) + d1 = Data(name="d1", data=None, nulldata=None) s.add(d1) s.commit() @@ -995,24 +891,46 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): ) eq_( s.query( - cast(self.tables.data_table.c.data, String(convert_unicode="force")), - cast(self.tables.data_table.c.nulldata, String) - ).filter(self.tables.data_table.c.name == 'd1').first(), - ("null", None) + cast( + self.tables.data_table.c.data, + String(convert_unicode="force"), + ), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), ) eq_( s.query( - cast(self.tables.data_table.c.data, String(convert_unicode="force")), - cast(self.tables.data_table.c.nulldata, String) - ).filter(self.tables.data_table.c.name == 'd2').first(), - ("null", None) - ) - - -__all__ = ('UnicodeVarcharTest', 'UnicodeTextTest', 'JSONTest', - 'DateTest', 'DateTimeTest', 'TextTest', - 'NumericTest', 'IntegerTest', - 'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest', - 'TimeMicrosecondsTest', 'TimestampMicrosecondsTest', 'TimeTest', - 'DateTimeMicrosecondsTest', - 'DateHistoricTest', 'StringTest', 'BooleanTest') + cast( + self.tables.data_table.c.data, + String(convert_unicode="force"), + ), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), + ) + + +__all__ = ( + "UnicodeVarcharTest", + "UnicodeTextTest", + "JSONTest", + "DateTest", + "DateTimeTest", + "TextTest", + "NumericTest", + "IntegerTest", + "DateTimeHistoricTest", + "DateTimeCoercedToDateTimeTest", + "TimeMicrosecondsTest", + "TimestampMicrosecondsTest", + "TimeTest", + "DateTimeMicrosecondsTest", + "DateHistoricTest", + "StringTest", + "BooleanTest", +) diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py index e4c61e74a..b232c3a78 100644 --- a/lib/sqlalchemy/testing/suite/test_update_delete.py +++ b/lib/sqlalchemy/testing/suite/test_update_delete.py @@ -6,15 +6,17 @@ from ..schema import Table, Column class SimpleUpdateDeleteTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" __backend__ = True @classmethod def define_tables(cls, metadata): - Table('plain_pk', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) @classmethod def insert_data(cls): @@ -24,40 +26,29 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest): {"id": 1, "data": "d1"}, {"id": 2, "data": "d2"}, {"id": 3, "data": "d3"}, - ] + ], ) def test_update(self): t = self.tables.plain_pk - r = config.db.execute( - t.update().where(t.c.id == 2), - data="d2_new" - ) + r = config.db.execute(t.update().where(t.c.id == 2), data="d2_new") assert not r.is_insert assert not r.returns_rows eq_( config.db.execute(t.select().order_by(t.c.id)).fetchall(), - [ - (1, "d1"), - (2, "d2_new"), - (3, "d3") - ] + [(1, "d1"), (2, "d2_new"), (3, "d3")], ) def test_delete(self): t = self.tables.plain_pk - r = config.db.execute( - t.delete().where(t.c.id == 2) - ) + r = config.db.execute(t.delete().where(t.c.id == 2)) assert not r.is_insert assert not r.returns_rows eq_( config.db.execute(t.select().order_by(t.c.id)).fetchall(), - [ - (1, "d1"), - (3, "d3") - ] + [(1, "d1"), (3, "d3")], ) -__all__ = ('SimpleUpdateDeleteTest', ) + +__all__ = ("SimpleUpdateDeleteTest",) diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 409d3bda5..5b015d214 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -14,6 +14,7 @@ import sys import types if jython: + def jython_gc_collect(*args): """aggressive gc.collect for tests.""" gc.collect() @@ -25,9 +26,11 @@ if jython: # "lazy" gc, for VM's that don't GC on refcount == 0 gc_collect = lazy_gc = jython_gc_collect elif pypy: + def pypy_gc_collect(*args): gc.collect() gc.collect() + gc_collect = lazy_gc = pypy_gc_collect else: # assume CPython - straight gc.collect, lazy_gc() is a pass @@ -42,11 +45,13 @@ def picklers(): if py2k: try: import cPickle + picklers.add(cPickle) except ImportError: pass import pickle + picklers.add(pickle) # yes, this thing needs this much testing @@ -60,9 +65,9 @@ def round_decimal(value, prec): return round(value, prec) # can also use shift() here but that is 2.6 only - return (value * decimal.Decimal("1" + "0" * prec) - ).to_integral(decimal.ROUND_FLOOR) / \ - pow(10, prec) + return (value * decimal.Decimal("1" + "0" * prec)).to_integral( + decimal.ROUND_FLOOR + ) / pow(10, prec) class RandomSet(set): @@ -137,8 +142,9 @@ def function_named(fn, name): try: fn.__name__ = name except TypeError: - fn = types.FunctionType(fn.__code__, fn.__globals__, name, - fn.__defaults__, fn.__closure__) + fn = types.FunctionType( + fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__ + ) return fn @@ -190,7 +196,7 @@ def provide_metadata(fn, *args, **kw): metadata = schema.MetaData(config.db) self = args[0] - prev_meta = getattr(self, 'metadata', None) + prev_meta = getattr(self, "metadata", None) self.metadata = metadata try: return fn(*args, **kw) @@ -213,8 +219,8 @@ def force_drop_names(*names): try: return fn(*args, **kw) finally: - drop_all_tables( - config.db, inspect(config.db), include_names=names) + drop_all_tables(config.db, inspect(config.db), include_names=names) + return go @@ -234,8 +240,13 @@ class adict(dict): def drop_all_tables(engine, inspector, schema=None, include_names=None): - from sqlalchemy import Column, Table, Integer, MetaData, \ - ForeignKeyConstraint + from sqlalchemy import ( + Column, + Table, + Integer, + MetaData, + ForeignKeyConstraint, + ) from sqlalchemy.schema import DropTable, DropConstraint if include_names is not None: @@ -243,30 +254,35 @@ def drop_all_tables(engine, inspector, schema=None, include_names=None): with engine.connect() as conn: for tname, fkcs in reversed( - inspector.get_sorted_table_and_fkc_names(schema=schema)): + inspector.get_sorted_table_and_fkc_names(schema=schema) + ): if tname: if include_names is not None and tname not in include_names: continue - conn.execute(DropTable( - Table(tname, MetaData(), schema=schema) - )) + conn.execute( + DropTable(Table(tname, MetaData(), schema=schema)) + ) elif fkcs: if not engine.dialect.supports_alter: continue for tname, fkc in fkcs: - if include_names is not None and \ - tname not in include_names: + if ( + include_names is not None + and tname not in include_names + ): continue tb = Table( - tname, MetaData(), - Column('x', Integer), - Column('y', Integer), - schema=schema + tname, + MetaData(), + Column("x", Integer), + Column("y", Integer), + schema=schema, + ) + conn.execute( + DropConstraint( + ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc) + ) ) - conn.execute(DropConstraint( - ForeignKeyConstraint( - [tb.c.x], [tb.c.y], name=fkc) - )) def teardown_events(event_cls): @@ -276,5 +292,5 @@ def teardown_events(event_cls): return fn(*arg, **kw) finally: event_cls._clear() - return decorate + return decorate diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index 46e7c54db..e0101b14d 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -15,17 +15,20 @@ from . import assertions def setup_filters(): """Set global warning behavior for the test suite.""" - warnings.filterwarnings('ignore', - category=sa_exc.SAPendingDeprecationWarning) - warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) - warnings.filterwarnings('error', category=sa_exc.SAWarning) + warnings.filterwarnings( + "ignore", category=sa_exc.SAPendingDeprecationWarning + ) + warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning) + warnings.filterwarnings("error", category=sa_exc.SAWarning) # some selected deprecations... - warnings.filterwarnings('error', category=DeprecationWarning) + warnings.filterwarnings("error", category=DeprecationWarning) warnings.filterwarnings( - "ignore", category=DeprecationWarning, message=".*StopIteration") + "ignore", category=DeprecationWarning, message=".*StopIteration" + ) warnings.filterwarnings( - "ignore", category=DeprecationWarning, message=".*inspect.getargspec") + "ignore", category=DeprecationWarning, message=".*inspect.getargspec" + ) def assert_warnings(fn, warning_msgs, regex=False): @@ -36,6 +39,6 @@ def assert_warnings(fn, warning_msgs, regex=False): """ with assertions._expect_warnings( - sa_exc.SAWarning, warning_msgs, regex=regex): + sa_exc.SAWarning, warning_msgs, regex=regex + ): return fn() - |
