diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
| commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
| tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/testing | |
| parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
| download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz | |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/testing')
33 files changed, 2399 insertions, 2114 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 413a492b8..f46ca4528 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -10,23 +10,62 @@ from .warnings import assert_warnings from . import config -from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\ - fails_on, fails_on_everything_except, skip, only_on, exclude, \ - against as _against, _server_version, only_if, fails +from .exclusions import ( + db_spec, + _is_excluded, + fails_if, + skip_if, + future, + fails_on, + fails_on_everything_except, + skip, + only_on, + exclude, + against as _against, + _server_version, + only_if, + fails, +) def against(*queries): return _against(config._current, *queries) -from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ - eq_, ne_, le_, is_, is_not_, startswith_, assert_raises, \ - assert_raises_message, AssertsCompiledSQL, ComparesTables, \ - AssertsExecutionResults, expect_deprecated, expect_warnings, \ - in_, not_in_, eq_ignore_whitespace, eq_regex, is_true, is_false -from .util import run_as_contextmanager, rowset, fail, \ - provide_metadata, adict, force_drop_names, \ - teardown_events +from .assertions import ( + emits_warning, + emits_warning_on, + uses_deprecated, + eq_, + ne_, + le_, + is_, + is_not_, + startswith_, + assert_raises, + assert_raises_message, + AssertsCompiledSQL, + ComparesTables, + AssertsExecutionResults, + expect_deprecated, + expect_warnings, + in_, + not_in_, + eq_ignore_whitespace, + eq_regex, + is_true, + is_false, +) + +from .util import ( + run_as_contextmanager, + rowset, + fail, + provide_metadata, + adict, + force_drop_names, + teardown_events, +) crashes = skip diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index e42376921..73ab4556a 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -86,6 +86,7 @@ def emits_warning_on(db, *messages): were in fact seen. """ + @decorator def decorate(fn, *args, **kw): with expect_warnings_on(db, assert_=False, *messages): @@ -114,12 +115,14 @@ def uses_deprecated(*messages): def decorate(fn, *args, **kw): with expect_deprecated(*messages, assert_=False): return fn(*args, **kw) + return decorate @contextlib.contextmanager -def _expect_warnings(exc_cls, messages, regex=True, assert_=True, - py2konly=False): +def _expect_warnings( + exc_cls, messages, regex=True, assert_=True, py2konly=False +): if regex: filters = [re.compile(msg, re.I | re.S) for msg in messages] @@ -145,8 +148,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True, return for filter_ in filters: - if (regex and filter_.match(msg)) or \ - (not regex and filter_ == msg): + if (regex and filter_.match(msg)) or ( + not regex and filter_ == msg + ): seen.discard(filter_) break else: @@ -156,8 +160,9 @@ def _expect_warnings(exc_cls, messages, regex=True, assert_=True, yield if assert_ and (not py2konly or not compat.py3k): - assert not seen, "Warnings were not seen: %s" % \ - ", ".join("%r" % (s.pattern if regex else s) for s in seen) + assert not seen, "Warnings were not seen: %s" % ", ".join( + "%r" % (s.pattern if regex else s) for s in seen + ) def global_cleanup_assertions(): @@ -170,6 +175,7 @@ def global_cleanup_assertions(): """ _assert_no_stray_pool_connections() + _STRAY_CONNECTION_FAILURES = 0 @@ -187,8 +193,10 @@ def _assert_no_stray_pool_connections(): # OK, let's be somewhat forgiving. _STRAY_CONNECTION_FAILURES += 1 - print("Encountered a stray connection in test cleanup: %s" - % str(pool._refs)) + print( + "Encountered a stray connection in test cleanup: %s" + % str(pool._refs) + ) # then do a real GC sweep. We shouldn't even be here # so a single sweep should really be doing it, otherwise # there's probably a real unreachable cycle somewhere. @@ -206,8 +214,8 @@ def _assert_no_stray_pool_connections(): pool._refs.clear() _STRAY_CONNECTION_FAILURES = 0 warnings.warn( - "Stray connection refused to leave " - "after gc.collect(): %s" % err) + "Stray connection refused to leave " "after gc.collect(): %s" % err + ) elif _STRAY_CONNECTION_FAILURES > 10: assert False, "Encountered more than 10 stray connections" _STRAY_CONNECTION_FAILURES = 0 @@ -263,14 +271,16 @@ def not_in_(a, b, msg=None): def startswith_(a, fragment, msg=None): """Assert a.startswith(fragment), with repr messaging on failure.""" assert a.startswith(fragment), msg or "%r does not start with %r" % ( - a, fragment) + a, + fragment, + ) def eq_ignore_whitespace(a, b, msg=None): - a = re.sub(r'^\s+?|\n', "", a) - a = re.sub(r' {2,}', " ", a) - b = re.sub(r'^\s+?|\n', "", b) - b = re.sub(r' {2,}', " ", b) + a = re.sub(r"^\s+?|\n", "", a) + a = re.sub(r" {2,}", " ", a) + b = re.sub(r"^\s+?|\n", "", b) + b = re.sub(r" {2,}", " ", b) assert a == b, msg or "%r != %r" % (a, b) @@ -291,32 +301,41 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): callable_(*args, **kwargs) assert False, "Callable did not raise an exception" except except_cls as e: - assert re.search( - msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (msg, e) - print(util.text_type(e).encode('utf-8')) + assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % ( + msg, + e, + ) + print(util.text_type(e).encode("utf-8")) + class AssertsCompiledSQL(object): - def assert_compile(self, clause, result, params=None, - checkparams=None, dialect=None, - checkpositional=None, - check_prefetch=None, - use_default_dialect=False, - allow_dialect_select=False, - literal_binds=False, - schema_translate_map=None): + def assert_compile( + self, + clause, + result, + params=None, + checkparams=None, + dialect=None, + checkpositional=None, + check_prefetch=None, + use_default_dialect=False, + allow_dialect_select=False, + literal_binds=False, + schema_translate_map=None, + ): if use_default_dialect: dialect = default.DefaultDialect() elif allow_dialect_select: dialect = None else: if dialect is None: - dialect = getattr(self, '__dialect__', None) + dialect = getattr(self, "__dialect__", None) if dialect is None: dialect = config.db.dialect - elif dialect == 'default': + elif dialect == "default": dialect = default.DefaultDialect() - elif dialect == 'default_enhanced': + elif dialect == "default_enhanced": dialect = default.StrCompileDialect() elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() @@ -325,13 +344,13 @@ class AssertsCompiledSQL(object): compile_kwargs = {} if schema_translate_map: - kw['schema_translate_map'] = schema_translate_map + kw["schema_translate_map"] = schema_translate_map if params is not None: - kw['column_keys'] = list(params) + kw["column_keys"] = list(params) if literal_binds: - compile_kwargs['literal_binds'] = True + compile_kwargs["literal_binds"] = True if isinstance(clause, orm.Query): context = clause._compile_context() @@ -343,25 +362,27 @@ class AssertsCompiledSQL(object): clause = stmt_mock.mock_calls[0][1][0] if compile_kwargs: - kw['compile_kwargs'] = compile_kwargs + kw["compile_kwargs"] = compile_kwargs c = clause.compile(dialect=dialect, **kw) - param_str = repr(getattr(c, 'params', {})) + param_str = repr(getattr(c, "params", {})) if util.py3k: - param_str = param_str.encode('utf-8').decode('ascii', 'ignore') + param_str = param_str.encode("utf-8").decode("ascii", "ignore") print( - ("\nSQL String:\n" + - util.text_type(c) + - param_str).encode('utf-8')) + ("\nSQL String:\n" + util.text_type(c) + param_str).encode( + "utf-8" + ) + ) else: print( - "\nSQL String:\n" + - util.text_type(c).encode('utf-8') + - param_str) + "\nSQL String:\n" + + util.text_type(c).encode("utf-8") + + param_str + ) - cc = re.sub(r'[\n\t]', '', util.text_type(c)) + cc = re.sub(r"[\n\t]", "", util.text_type(c)) eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) @@ -375,7 +396,6 @@ class AssertsCompiledSQL(object): class ComparesTables(object): - def assert_tables_equal(self, table, reflected_table, strict_types=False): assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): @@ -386,8 +406,10 @@ class ComparesTables(object): if strict_types: msg = "Type '%s' doesn't correspond to type '%s'" - assert isinstance(reflected_c.type, type(c.type)), \ - msg % (reflected_c.type, c.type) + assert isinstance(reflected_c.type, type(c.type)), msg % ( + reflected_c.type, + c.type, + ) else: self.assert_types_base(reflected_c, c) @@ -396,20 +418,22 @@ class ComparesTables(object): eq_( {f.column.name for f in c.foreign_keys}, - {f.column.name for f in reflected_c.foreign_keys} + {f.column.name for f in reflected_c.foreign_keys}, ) if c.server_default: - assert isinstance(reflected_c.server_default, - schema.FetchedValue) + assert isinstance( + reflected_c.server_default, schema.FetchedValue + ) assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] is not None def assert_types_base(self, c1, c2): - assert c1.type._compare_type_affinity(c2.type),\ - "On column %r, type '%s' doesn't correspond to type '%s'" % \ - (c1.name, c1.type, c2.type) + assert c1.type._compare_type_affinity(c2.type), ( + "On column %r, type '%s' doesn't correspond to type '%s'" + % (c1.name, c1.type, c2.type) + ) class AssertsExecutionResults(object): @@ -419,15 +443,19 @@ class AssertsExecutionResults(object): self.assert_list(result, class_, objects) def assert_list(self, result, class_, list): - self.assert_(len(result) == len(list), - "result list is not the same size as test list, " + - "for class " + class_.__name__) + self.assert_( + len(result) == len(list), + "result list is not the same size as test list, " + + "for class " + + class_.__name__, + ) for i in range(0, len(list)): self.assert_row(class_, result[i], list[i]) def assert_row(self, class_, rowobj, desc): - self.assert_(rowobj.__class__ is class_, - "item class is not " + repr(class_)) + self.assert_( + rowobj.__class__ is class_, "item class is not " + repr(class_) + ) for key, value in desc.items(): if isinstance(value, tuple): if isinstance(value[1], list): @@ -435,9 +463,11 @@ class AssertsExecutionResults(object): else: self.assert_row(value[0], getattr(rowobj, key), value[1]) else: - self.assert_(getattr(rowobj, key) == value, - "attribute %s value %s does not match %s" % ( - key, getattr(rowobj, key), value)) + self.assert_( + getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" + % (key, getattr(rowobj, key), value), + ) def assert_unordered_result(self, result, cls, *expected): """As assert_result, but the order of objects is not considered. @@ -453,14 +483,19 @@ class AssertsExecutionResults(object): found = util.IdentitySet(result) expected = {immutabledict(e) for e in expected} - for wrong in util.itertools_filterfalse(lambda o: - isinstance(o, cls), found): - fail('Unexpected type "%s", expected "%s"' % ( - type(wrong).__name__, cls.__name__)) + for wrong in util.itertools_filterfalse( + lambda o: isinstance(o, cls), found + ): + fail( + 'Unexpected type "%s", expected "%s"' + % (type(wrong).__name__, cls.__name__) + ) if len(found) != len(expected): - fail('Unexpected object count "%s", expected "%s"' % ( - len(found), len(expected))) + fail( + 'Unexpected object count "%s", expected "%s"' + % (len(found), len(expected)) + ) NOVALUE = object() @@ -469,7 +504,8 @@ class AssertsExecutionResults(object): if isinstance(value, tuple): try: self.assert_unordered_result( - getattr(obj, key), value[0], *value[1]) + getattr(obj, key), value[0], *value[1] + ) except AssertionError: return False else: @@ -484,8 +520,9 @@ class AssertsExecutionResults(object): break else: fail( - "Expected %s instance with attributes %s not found." % ( - cls.__name__, repr(expected_item))) + "Expected %s instance with attributes %s not found." + % (cls.__name__, repr(expected_item)) + ) return True def sql_execution_asserter(self, db=None): @@ -505,9 +542,9 @@ class AssertsExecutionResults(object): newrules = [] for rule in rules: if isinstance(rule, dict): - newrule = assertsql.AllOf(*[ - assertsql.CompiledSQL(k, v) for k, v in rule.items() - ]) + newrule = assertsql.AllOf( + *[assertsql.CompiledSQL(k, v) for k, v in rule.items()] + ) else: newrule = assertsql.CompiledSQL(*rule) newrules.append(newrule) @@ -516,7 +553,8 @@ class AssertsExecutionResults(object): def assert_sql_count(self, db, callable_, count): self.assert_sql_execution( - db, callable_, assertsql.CountStatements(count)) + db, callable_, assertsql.CountStatements(count) + ) def assert_multiple_sql_count(self, dbs, callable_, counts): recs = [ diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 7a525589d..d8e924cb6 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -26,8 +26,10 @@ class AssertRule(object): pass def no_more_statements(self): - assert False, 'All statements are complete, but pending '\ - 'assertion rules remain' + assert False, ( + "All statements are complete, but pending " + "assertion rules remain" + ) class SQLMatchRule(AssertRule): @@ -44,12 +46,17 @@ class CursorSQL(SQLMatchRule): def process_statement(self, execute_observed): stmt = execute_observed.statements[0] if self.statement != stmt.statement or ( - self.params is not None and self.params != stmt.parameters): - self.errormessage = \ - "Testing for exact SQL %s parameters %s received %s %s" % ( - self.statement, self.params, - stmt.statement, stmt.parameters + self.params is not None and self.params != stmt.parameters + ): + self.errormessage = ( + "Testing for exact SQL %s parameters %s received %s %s" + % ( + self.statement, + self.params, + stmt.statement, + stmt.parameters, ) + ) else: execute_observed.statements.pop(0) self.is_consumed = True @@ -58,23 +65,22 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): - - def __init__(self, statement, params=None, dialect='default'): + def __init__(self, statement, params=None, dialect="default"): self.statement = statement self.params = params self.dialect = dialect def _compare_sql(self, execute_observed, received_statement): - stmt = re.sub(r'[\n\t]', '', self.statement) + stmt = re.sub(r"[\n\t]", "", self.statement) return received_statement == stmt def _compile_dialect(self, execute_observed): - if self.dialect == 'default': + if self.dialect == "default": return DefaultDialect() else: # ugh - if self.dialect == 'postgresql': - params = {'implicit_returning': True} + if self.dialect == "postgresql": + params = {"implicit_returning": True} else: params = {} return url.URL(self.dialect).get_dialect()(**params) @@ -86,36 +92,39 @@ class CompiledSQL(SQLMatchRule): context = execute_observed.context compare_dialect = self._compile_dialect(execute_observed) if isinstance(context.compiled.statement, _DDLCompiles): - compiled = \ - context.compiled.statement.compile( - dialect=compare_dialect, - schema_translate_map=context. - execution_options.get('schema_translate_map')) + compiled = context.compiled.statement.compile( + dialect=compare_dialect, + schema_translate_map=context.execution_options.get( + "schema_translate_map" + ), + ) else: - compiled = ( - context.compiled.statement.compile( - dialect=compare_dialect, - column_keys=context.compiled.column_keys, - inline=context.compiled.inline, - schema_translate_map=context. - execution_options.get('schema_translate_map')) + compiled = context.compiled.statement.compile( + dialect=compare_dialect, + column_keys=context.compiled.column_keys, + inline=context.compiled.inline, + schema_translate_map=context.execution_options.get( + "schema_translate_map" + ), ) - _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled)) + _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled)) parameters = execute_observed.parameters if not parameters: _received_parameters = [compiled.construct_params()] else: _received_parameters = [ - compiled.construct_params(m) for m in parameters] + compiled.construct_params(m) for m in parameters + ] return _received_statement, _received_parameters def process_statement(self, execute_observed): context = execute_observed.context - _received_statement, _received_parameters = \ - self._received_statement(execute_observed) + _received_statement, _received_parameters = self._received_statement( + execute_observed + ) params = self._all_params(context) equivalent = self._compare_sql(execute_observed, _received_statement) @@ -132,8 +141,10 @@ class CompiledSQL(SQLMatchRule): for param_key in param: # a key in param did not match current # 'received' - if param_key not in received or \ - received[param_key] != param[param_key]: + if ( + param_key not in received + or received[param_key] != param[param_key] + ): break else: # all keys in param matched 'received'; @@ -153,8 +164,8 @@ class CompiledSQL(SQLMatchRule): self.errormessage = None else: self.errormessage = self._failure_message(params) % { - 'received_statement': _received_statement, - 'received_parameters': _received_parameters + "received_statement": _received_statement, + "received_parameters": _received_parameters, } def _all_params(self, context): @@ -171,11 +182,10 @@ class CompiledSQL(SQLMatchRule): def _failure_message(self, expected_params): return ( - 'Testing for compiled statement %r partial params %r, ' - 'received %%(received_statement)r with params ' - '%%(received_parameters)r' % ( - self.statement.replace('%', '%%'), expected_params - ) + "Testing for compiled statement %r partial params %r, " + "received %%(received_statement)r with params " + "%%(received_parameters)r" + % (self.statement.replace("%", "%%"), expected_params) ) @@ -185,15 +195,13 @@ class RegexSQL(CompiledSQL): self.regex = re.compile(regex) self.orig_regex = regex self.params = params - self.dialect = 'default' + self.dialect = "default" def _failure_message(self, expected_params): return ( - 'Testing for compiled statement ~%r partial params %r, ' - 'received %%(received_statement)r with params ' - '%%(received_parameters)r' % ( - self.orig_regex, expected_params - ) + "Testing for compiled statement ~%r partial params %r, " + "received %%(received_statement)r with params " + "%%(received_parameters)r" % (self.orig_regex, expected_params) ) def _compare_sql(self, execute_observed, received_statement): @@ -205,12 +213,13 @@ class DialectSQL(CompiledSQL): return execute_observed.context.dialect def _compare_no_space(self, real_stmt, received_stmt): - stmt = re.sub(r'[\n\t]', '', real_stmt) + stmt = re.sub(r"[\n\t]", "", real_stmt) return received_stmt == stmt def _received_statement(self, execute_observed): - received_stmt, received_params = super(DialectSQL, self).\ - _received_statement(execute_observed) + received_stmt, received_params = super( + DialectSQL, self + )._received_statement(execute_observed) # TODO: why do we need this part? for real_stmt in execute_observed.statements: @@ -219,34 +228,33 @@ class DialectSQL(CompiledSQL): else: raise AssertionError( "Can't locate compiled statement %r in list of " - "statements actually invoked" % received_stmt) + "statements actually invoked" % received_stmt + ) return received_stmt, execute_observed.context.compiled_parameters def _compare_sql(self, execute_observed, received_statement): - stmt = re.sub(r'[\n\t]', '', self.statement) + stmt = re.sub(r"[\n\t]", "", self.statement) # convert our comparison statement to have the # paramstyle of the received paramstyle = execute_observed.context.dialect.paramstyle - if paramstyle == 'pyformat': - stmt = re.sub( - r':([\w_]+)', r"%(\1)s", stmt) + if paramstyle == "pyformat": + stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt) else: # positional params repl = None - if paramstyle == 'qmark': + if paramstyle == "qmark": repl = "?" - elif paramstyle == 'format': + elif paramstyle == "format": repl = r"%s" - elif paramstyle == 'numeric': + elif paramstyle == "numeric": repl = None - stmt = re.sub(r':([\w_]+)', repl, stmt) + stmt = re.sub(r":([\w_]+)", repl, stmt) return received_statement == stmt class CountStatements(AssertRule): - def __init__(self, count): self.count = count self._statement_count = 0 @@ -256,12 +264,13 @@ class CountStatements(AssertRule): def no_more_statements(self): if self.count != self._statement_count: - assert False, 'desired statement count %d does not match %d' \ - % (self.count, self._statement_count) + assert False, "desired statement count %d does not match %d" % ( + self.count, + self._statement_count, + ) class AllOf(AssertRule): - def __init__(self, *rules): self.rules = set(rules) @@ -283,7 +292,6 @@ class AllOf(AssertRule): class EachOf(AssertRule): - def __init__(self, *rules): self.rules = list(rules) @@ -309,7 +317,6 @@ class EachOf(AssertRule): class Or(AllOf): - def process_statement(self, execute_observed): for rule in self.rules: rule.process_statement(execute_observed) @@ -331,7 +338,8 @@ class SQLExecuteObserved(object): class SQLCursorExecuteObserved( collections.namedtuple( "SQLCursorExecuteObserved", - ["statement", "parameters", "context", "executemany"]) + ["statement", "parameters", "context", "executemany"], + ) ): pass @@ -374,21 +382,25 @@ def assert_engine(engine): orig[:] = clauseelement, multiparams, params @event.listens_for(engine, "after_cursor_execute") - def cursor_execute(conn, cursor, statement, parameters, - context, executemany): + def cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): if not context: return # then grab real cursor statements and associate them all # around a single context - if asserter.accumulated and \ - asserter.accumulated[-1].context is context: + if ( + asserter.accumulated + and asserter.accumulated[-1].context is context + ): obs = asserter.accumulated[-1] else: obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2]) asserter.accumulated.append(obs) obs.statements.append( SQLCursorExecuteObserved( - statement, parameters, context, executemany) + statement, parameters, context, executemany + ) ) try: diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index e9cfb3de9..1ff282af5 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -64,8 +64,9 @@ class Config(object): assert _current, "Can't push without a default Config set up" cls.push( Config( - db, _current.db_opts, _current.options, _current.file_config), - namespace + db, _current.db_opts, _current.options, _current.file_config + ), + namespace, ) @classmethod @@ -94,4 +95,3 @@ class Config(object): def skip_test(msg): raise _skip_test_exception(msg) - diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index d17e30edf..074e3b338 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -16,7 +16,6 @@ import warnings class ConnectionKiller(object): - def __init__(self): self.proxy_refs = weakref.WeakKeyDictionary() self.testing_engines = weakref.WeakKeyDictionary() @@ -39,8 +38,8 @@ class ConnectionKiller(object): fn() except Exception as e: warnings.warn( - "testing_reaper couldn't " - "rollback/close connection: %s" % e) + "testing_reaper couldn't " "rollback/close connection: %s" % e + ) def rollback_all(self): for rec in list(self.proxy_refs): @@ -97,18 +96,19 @@ class ConnectionKiller(object): if rec.is_valid: assert False + testing_reaper = ConnectionKiller() def drop_all_tables(metadata, bind): testing_reaper.close_all() - if hasattr(bind, 'close'): + if hasattr(bind, "close"): bind.close() if not config.db.dialect.supports_alter: from . import assertions - with assertions.expect_warnings( - "Can't sort tables", assert_=False): + + with assertions.expect_warnings("Can't sort tables", assert_=False): metadata.drop_all(bind) else: metadata.drop_all(bind) @@ -151,19 +151,20 @@ def close_open_connections(fn, *args, **kw): def all_dialects(exclude=None): import sqlalchemy.databases as d + for name in d.__all__: # TEMPORARY if exclude and name in exclude: continue mod = getattr(d, name, None) if not mod: - mod = getattr(__import__( - 'sqlalchemy.databases.%s' % name).databases, name) + mod = getattr( + __import__("sqlalchemy.databases.%s" % name).databases, name + ) yield mod.dialect() class ReconnectFixture(object): - def __init__(self, dbapi): self.dbapi = dbapi self.connections = [] @@ -191,8 +192,8 @@ class ReconnectFixture(object): fn() except Exception as e: warnings.warn( - "ReconnectFixture couldn't " - "close connection: %s" % e) + "ReconnectFixture couldn't " "close connection: %s" % e + ) def shutdown(self, stop=False): # TODO: this doesn't cover all cases @@ -214,7 +215,7 @@ def reconnecting_engine(url=None, options=None): dbapi = config.db.dialect.dbapi if not options: options = {} - options['module'] = ReconnectFixture(dbapi) + options["module"] = ReconnectFixture(dbapi) engine = testing_engine(url, options) _dispose = engine.dispose @@ -238,7 +239,7 @@ def testing_engine(url=None, options=None): if not options: use_reaper = True else: - use_reaper = options.pop('use_reaper', True) + use_reaper = options.pop("use_reaper", True) url = url or config.db.url @@ -253,15 +254,15 @@ def testing_engine(url=None, options=None): default_opt.update(options) engine = create_engine(url, **options) - engine._has_events = True # enable event blocks, helps with profiling + engine._has_events = True # enable event blocks, helps with profiling if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 engine.pool._max_overflow = 0 if use_reaper: - event.listen(engine.pool, 'connect', testing_reaper.connect) - event.listen(engine.pool, 'checkout', testing_reaper.checkout) - event.listen(engine.pool, 'invalidate', testing_reaper.invalidate) + event.listen(engine.pool, "connect", testing_reaper.connect) + event.listen(engine.pool, "checkout", testing_reaper.checkout) + event.listen(engine.pool, "invalidate", testing_reaper.invalidate) testing_reaper.add_engine(engine) return engine @@ -290,19 +291,17 @@ def mock_engine(dialect_name=None): buffer.append(sql) def assert_sql(stmts): - recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] + recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer] assert recv == stmts, recv def print_sql(): d = engine.dialect - return "\n".join( - str(s.compile(dialect=d)) - for s in engine.mock - ) - - engine = create_engine(dialect_name + '://', - strategy='mock', executor=executor) - assert not hasattr(engine, 'mock') + return "\n".join(str(s.compile(dialect=d)) for s in engine.mock) + + engine = create_engine( + dialect_name + "://", strategy="mock", executor=executor + ) + assert not hasattr(engine, "mock") engine.mock = buffer engine.assert_sql = assert_sql engine.print_sql = print_sql @@ -358,14 +357,15 @@ class DBAPIProxyConnection(object): return getattr(self.conn, key) -def proxying_engine(conn_cls=DBAPIProxyConnection, - cursor_cls=DBAPIProxyCursor): +def proxying_engine( + conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor +): """Produce an engine that provides proxy hooks for common methods. """ + def mock_conn(): return conn_cls(config.db, cursor_cls) - return testing_engine(options={'creator': mock_conn}) - + return testing_engine(options={"creator": mock_conn}) diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py index b634735fc..42c42149c 100644 --- a/lib/sqlalchemy/testing/entities.py +++ b/lib/sqlalchemy/testing/entities.py @@ -12,7 +12,6 @@ _repr_stack = set() class BasicEntity(object): - def __init__(self, **kw): for key, value in kw.items(): setattr(self, key, value) @@ -24,17 +23,22 @@ class BasicEntity(object): try: return "%s(%s)" % ( (self.__class__.__name__), - ', '.join(["%s=%r" % (key, getattr(self, key)) - for key in sorted(self.__dict__.keys()) - if not key.startswith('_')])) + ", ".join( + [ + "%s=%r" % (key, getattr(self, key)) + for key in sorted(self.__dict__.keys()) + if not key.startswith("_") + ] + ), + ) finally: _repr_stack.remove(id(self)) + _recursion_stack = set() class ComparableEntity(BasicEntity): - def __hash__(self): return hash(self.__class__) @@ -75,7 +79,7 @@ class ComparableEntity(BasicEntity): b = other for attr in list(a.__dict__): - if attr.startswith('_'): + if attr.startswith("_"): continue value = getattr(a, attr) @@ -85,9 +89,10 @@ class ComparableEntity(BasicEntity): except (AttributeError, sa_exc.UnboundExecutionError): return False - if hasattr(value, '__iter__'): - if hasattr(value, '__getitem__') and not hasattr( - value, 'keys'): + if hasattr(value, "__iter__"): + if hasattr(value, "__getitem__") and not hasattr( + value, "keys" + ): if list(value) != list(battr): return False else: diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 512fffb3b..9ed9e42c3 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -16,6 +16,7 @@ from . import config from .. import util from ..util import decorator + def skip_if(predicate, reason=None): rule = compound() pred = _as_predicate(predicate, reason) @@ -70,15 +71,15 @@ class compound(object): def matching_config_reasons(self, config): return [ - predicate._as_string(config) for predicate - in self.skips.union(self.fails) + predicate._as_string(config) + for predicate in self.skips.union(self.fails) if predicate(config) ] def include_test(self, include_tags, exclude_tags): return bool( - not self.tags.intersection(exclude_tags) and - (not include_tags or self.tags.intersection(include_tags)) + not self.tags.intersection(exclude_tags) + and (not include_tags or self.tags.intersection(include_tags)) ) def _extend(self, other): @@ -87,13 +88,14 @@ class compound(object): self.tags.update(other.tags) def __call__(self, fn): - if hasattr(fn, '_sa_exclusion_extend'): + if hasattr(fn, "_sa_exclusion_extend"): fn._sa_exclusion_extend._extend(self) return fn @decorator def decorate(fn, *args, **kw): return self._do(config._current, fn, *args, **kw) + decorated = decorate(fn) decorated._sa_exclusion_extend = self return decorated @@ -113,10 +115,7 @@ class compound(object): def _do(self, cfg, fn, *args, **kw): for skip in self.skips: if skip(cfg): - msg = "'%s' : %s" % ( - fn.__name__, - skip._as_string(cfg) - ) + msg = "'%s' : %s" % (fn.__name__, skip._as_string(cfg)) config.skip_test(msg) try: @@ -127,16 +126,20 @@ class compound(object): self._expect_success(cfg, name=fn.__name__) return return_value - def _expect_failure(self, config, ex, name='block'): + def _expect_failure(self, config, ex, name="block"): for fail in self.fails: if fail(config): - print(("%s failed as expected (%s): %s " % ( - name, fail._as_string(config), str(ex)))) + print( + ( + "%s failed as expected (%s): %s " + % (name, fail._as_string(config), str(ex)) + ) + ) break else: util.raise_from_cause(ex) - def _expect_success(self, config, name='block'): + def _expect_success(self, config, name="block"): if not self.fails: return for fail in self.fails: @@ -144,13 +147,12 @@ class compound(object): break else: raise AssertionError( - "Unexpected success for '%s' (%s)" % - ( + "Unexpected success for '%s' (%s)" + % ( name, " and ".join( - fail._as_string(config) - for fail in self.fails - ) + fail._as_string(config) for fail in self.fails + ), ) ) @@ -186,21 +188,24 @@ class Predicate(object): return predicate elif isinstance(predicate, (list, set)): return OrPredicate( - [cls.as_predicate(pred) for pred in predicate], - description) + [cls.as_predicate(pred) for pred in predicate], description + ) elif isinstance(predicate, tuple): return SpecPredicate(*predicate) elif isinstance(predicate, util.string_types): tokens = re.match( - r'([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?', predicate) + r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate + ) if not tokens: raise ValueError( - "Couldn't locate DB name in predicate: %r" % predicate) + "Couldn't locate DB name in predicate: %r" % predicate + ) db = tokens.group(1) op = tokens.group(2) spec = ( tuple(int(d) for d in tokens.group(3).split(".")) - if tokens.group(3) else None + if tokens.group(3) + else None ) return SpecPredicate(db, op, spec, description=description) @@ -215,11 +220,13 @@ class Predicate(object): bool_ = not negate return self.description % { "driver": config.db.url.get_driver_name() - if config else "<no driver>", + if config + else "<no driver>", "database": config.db.url.get_backend_name() - if config else "<no database>", + if config + else "<no database>", "doesnt_support": "doesn't support" if bool_ else "does support", - "does_support": "does support" if bool_ else "doesn't support" + "does_support": "does support" if bool_ else "doesn't support", } def _as_string(self, config=None, negate=False): @@ -246,21 +253,21 @@ class SpecPredicate(Predicate): self.description = description _ops = { - '<': operator.lt, - '>': operator.gt, - '==': operator.eq, - '!=': operator.ne, - '<=': operator.le, - '>=': operator.ge, - 'in': operator.contains, - 'between': lambda val, pair: val >= pair[0] and val <= pair[1], + "<": operator.lt, + ">": operator.gt, + "==": operator.eq, + "!=": operator.ne, + "<=": operator.le, + ">=": operator.ge, + "in": operator.contains, + "between": lambda val, pair: val >= pair[0] and val <= pair[1], } def __call__(self, config): engine = config.db if "+" in self.db: - dialect, driver = self.db.split('+') + dialect, driver = self.db.split("+") else: dialect, driver = self.db, None @@ -273,8 +280,9 @@ class SpecPredicate(Predicate): assert driver is None, "DBAPI version specs not supported yet" version = _server_version(engine) - oper = hasattr(self.op, '__call__') and self.op \ - or self._ops[self.op] + oper = ( + hasattr(self.op, "__call__") and self.op or self._ops[self.op] + ) return oper(version, self.spec) else: return True @@ -289,17 +297,9 @@ class SpecPredicate(Predicate): return "%s" % self.db else: if negate: - return "not %s %s %s" % ( - self.db, - self.op, - self.spec - ) + return "not %s %s %s" % (self.db, self.op, self.spec) else: - return "%s %s %s" % ( - self.db, - self.op, - self.spec - ) + return "%s %s %s" % (self.db, self.op, self.spec) class LambdaPredicate(Predicate): @@ -356,8 +356,9 @@ class OrPredicate(Predicate): conjunction = " and " else: conjunction = " or " - return conjunction.join(p._as_string(config, negate=negate) - for p in self.predicates) + return conjunction.join( + p._as_string(config, negate=negate) for p in self.predicates + ) def _negation_str(self, config): if self.description is not None: @@ -387,7 +388,7 @@ def _server_version(engine): # force metadata to be retrieved conn = engine.connect() - version = getattr(engine.dialect, 'server_version_info', None) + version = getattr(engine.dialect, "server_version_info", None) if version is None: version = () conn.close() @@ -395,9 +396,7 @@ def _server_version(engine): def db_spec(*dbs): - return OrPredicate( - [Predicate.as_predicate(db) for db in dbs] - ) + return OrPredicate([Predicate.as_predicate(db) for db in dbs]) def open(): @@ -422,11 +421,7 @@ def fails_on(db, reason=None): def fails_on_everything_except(*dbs): - return succeeds_if( - OrPredicate([ - Predicate.as_predicate(db) for db in dbs - ]) - ) + return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs])) def skip(db, reason=None): @@ -435,8 +430,9 @@ def skip(db, reason=None): def only_on(dbs, reason=None): return only_if( - OrPredicate([Predicate.as_predicate(db, reason) - for db in util.to_list(dbs)]) + OrPredicate( + [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)] + ) ) @@ -446,7 +442,6 @@ def exclude(db, op, spec, reason=None): def against(config, *queries): assert queries, "no queries sent!" - return OrPredicate([ - Predicate.as_predicate(query) - for query in queries - ])(config) + return OrPredicate([Predicate.as_predicate(query) for query in queries])( + config + ) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index dd0fa5a48..98184cdd4 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -54,19 +54,19 @@ class TestBase(object): class TablesTest(TestBase): # 'once', None - run_setup_bind = 'once' + run_setup_bind = "once" # 'once', 'each', None - run_define_tables = 'once' + run_define_tables = "once" # 'once', 'each', None - run_create_tables = 'once' + run_create_tables = "once" # 'once', 'each', None - run_inserts = 'each' + run_inserts = "each" # 'each', None - run_deletes = 'each' + run_deletes = "each" # 'once', None run_dispose_bind = None @@ -86,10 +86,10 @@ class TablesTest(TestBase): @classmethod def _init_class(cls): - if cls.run_define_tables == 'each': - if cls.run_create_tables == 'once': - cls.run_create_tables = 'each' - assert cls.run_inserts in ('each', None) + if cls.run_define_tables == "each": + if cls.run_create_tables == "once": + cls.run_create_tables = "each" + assert cls.run_inserts in ("each", None) cls.other = adict() cls.tables = adict() @@ -100,40 +100,40 @@ class TablesTest(TestBase): @classmethod def _setup_once_inserts(cls): - if cls.run_inserts == 'once': + if cls.run_inserts == "once": cls._load_fixtures() cls.insert_data() @classmethod def _setup_once_tables(cls): - if cls.run_define_tables == 'once': + if cls.run_define_tables == "once": cls.define_tables(cls.metadata) - if cls.run_create_tables == 'once': + if cls.run_create_tables == "once": cls.metadata.create_all(cls.bind) cls.tables.update(cls.metadata.tables) def _setup_each_tables(self): - if self.run_define_tables == 'each': + if self.run_define_tables == "each": self.tables.clear() - if self.run_create_tables == 'each': + if self.run_create_tables == "each": drop_all_tables(self.metadata, self.bind) self.metadata.clear() self.define_tables(self.metadata) - if self.run_create_tables == 'each': + if self.run_create_tables == "each": self.metadata.create_all(self.bind) self.tables.update(self.metadata.tables) - elif self.run_create_tables == 'each': + elif self.run_create_tables == "each": drop_all_tables(self.metadata, self.bind) self.metadata.create_all(self.bind) def _setup_each_inserts(self): - if self.run_inserts == 'each': + if self.run_inserts == "each": self._load_fixtures() self.insert_data() def _teardown_each_tables(self): # no need to run deletes if tables are recreated on setup - if self.run_define_tables != 'each' and self.run_deletes == 'each': + if self.run_define_tables != "each" and self.run_deletes == "each": with self.bind.connect() as conn: for table in reversed(self.metadata.sorted_tables): try: @@ -141,7 +141,8 @@ class TablesTest(TestBase): except sa.exc.DBAPIError as ex: util.print_( ("Error emptying table %s: %r" % (table, ex)), - file=sys.stderr) + file=sys.stderr, + ) def setup(self): self._setup_each_tables() @@ -155,7 +156,7 @@ class TablesTest(TestBase): if cls.run_create_tables: drop_all_tables(cls.metadata, cls.bind) - if cls.run_dispose_bind == 'once': + if cls.run_dispose_bind == "once": cls.dispose_bind(cls.bind) cls.metadata.bind = None @@ -173,9 +174,9 @@ class TablesTest(TestBase): @classmethod def dispose_bind(cls, bind): - if hasattr(bind, 'dispose'): + if hasattr(bind, "dispose"): bind.dispose() - elif hasattr(bind, 'close'): + elif hasattr(bind, "close"): bind.close() @classmethod @@ -212,8 +213,12 @@ class TablesTest(TestBase): continue cls.bind.execute( table.insert(), - [dict(zip(headers[table], column_values)) - for column_values in rows[table]]) + [ + dict(zip(headers[table], column_values)) + for column_values in rows[table] + ], + ) + from sqlalchemy import event @@ -236,7 +241,6 @@ class RemovesEvents(object): class _ORMTest(object): - @classmethod def teardown_class(cls): sa.orm.session.Session.close_all() @@ -249,10 +253,10 @@ class ORMTest(_ORMTest, TestBase): class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): # 'once', 'each', None - run_setup_classes = 'once' + run_setup_classes = "once" # 'once', 'each', None - run_setup_mappers = 'each' + run_setup_mappers = "each" classes = None @@ -292,20 +296,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): @classmethod def _setup_once_classes(cls): - if cls.run_setup_classes == 'once': + if cls.run_setup_classes == "once": cls._with_register_classes(cls.setup_classes) @classmethod def _setup_once_mappers(cls): - if cls.run_setup_mappers == 'once': + if cls.run_setup_mappers == "once": cls._with_register_classes(cls.setup_mappers) def _setup_each_mappers(self): - if self.run_setup_mappers == 'each': + if self.run_setup_mappers == "each": self._with_register_classes(self.setup_mappers) def _setup_each_classes(self): - if self.run_setup_classes == 'each': + if self.run_setup_classes == "each": self._with_register_classes(self.setup_classes) @classmethod @@ -339,11 +343,11 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): # some tests create mappers in the test bodies # and will define setup_mappers as None - # clear mappers in any case - if self.run_setup_mappers != 'once': + if self.run_setup_mappers != "once": sa.orm.clear_mappers() def _teardown_each_classes(self): - if self.run_setup_classes != 'once': + if self.run_setup_classes != "once": self.classes.clear() @classmethod @@ -356,8 +360,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): class DeclarativeMappedTest(MappedTest): - run_setup_classes = 'once' - run_setup_mappers = 'once' + run_setup_classes = "once" + run_setup_mappers = "once" @classmethod def _setup_once_tables(cls): @@ -370,15 +374,16 @@ class DeclarativeMappedTest(MappedTest): class FindFixtureDeclarative(DeclarativeMeta): def __init__(cls, classname, bases, dict_): cls_registry[classname] = cls - return DeclarativeMeta.__init__( - cls, classname, bases, dict_) + return DeclarativeMeta.__init__(cls, classname, bases, dict_) class DeclarativeBasic(object): __table_cls__ = schema.Table - _DeclBase = declarative_base(metadata=cls.metadata, - metaclass=FindFixtureDeclarative, - cls=DeclarativeBasic) + _DeclBase = declarative_base( + metadata=cls.metadata, + metaclass=FindFixtureDeclarative, + cls=DeclarativeBasic, + ) cls.DeclarativeBasic = _DeclBase fn() diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py index ea0a8da82..dc530af5e 100644 --- a/lib/sqlalchemy/testing/mock.py +++ b/lib/sqlalchemy/testing/mock.py @@ -18,4 +18,5 @@ else: except ImportError: raise ImportError( "SQLAlchemy's test suite requires the " - "'mock' library as of 0.8.2.") + "'mock' library as of 0.8.2." + ) diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py index 087fc1fe6..e84cbde44 100644 --- a/lib/sqlalchemy/testing/pickleable.py +++ b/lib/sqlalchemy/testing/pickleable.py @@ -46,29 +46,28 @@ class Parent(fixtures.ComparableEntity): class Screen(object): - def __init__(self, obj, parent=None): self.obj = obj self.parent = parent class Foo(object): - def __init__(self, moredata): - self.data = 'im data' - self.stuff = 'im stuff' + self.data = "im data" + self.stuff = "im stuff" self.moredata = moredata __hash__ = object.__hash__ def __eq__(self, other): - return other.data == self.data and \ - other.stuff == self.stuff and \ - other.moredata == self.moredata + return ( + other.data == self.data + and other.stuff == self.stuff + and other.moredata == self.moredata + ) class Bar(object): - def __init__(self, x, y): self.x = x self.y = y @@ -76,35 +75,36 @@ class Bar(object): __hash__ = object.__hash__ def __eq__(self, other): - return other.__class__ is self.__class__ and \ - other.x == self.x and \ - other.y == self.y + return ( + other.__class__ is self.__class__ + and other.x == self.x + and other.y == self.y + ) def __str__(self): return "Bar(%d, %d)" % (self.x, self.y) class OldSchool: - def __init__(self, x, y): self.x = x self.y = y def __eq__(self, other): - return other.__class__ is self.__class__ and \ - other.x == self.x and \ - other.y == self.y + return ( + other.__class__ is self.__class__ + and other.x == self.x + and other.y == self.y + ) class OldSchoolWithoutCompare: - def __init__(self, x, y): self.x = x self.y = y class BarWithoutCompare(object): - def __init__(self, x, y): self.x = x self.y = y @@ -114,7 +114,6 @@ class BarWithoutCompare(object): class NotComparable(object): - def __init__(self, data): self.data = data @@ -129,7 +128,6 @@ class NotComparable(object): class BrokenComparable(object): - def __init__(self, data): self.data = data diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py index 497fcb7e5..bb52c125c 100644 --- a/lib/sqlalchemy/testing/plugin/bootstrap.py +++ b/lib/sqlalchemy/testing/plugin/bootstrap.py @@ -20,20 +20,23 @@ this should be removable when Alembic targets SQLAlchemy 1.0.0. import os import sys -bootstrap_file = locals()['bootstrap_file'] -to_bootstrap = locals()['to_bootstrap'] +bootstrap_file = locals()["bootstrap_file"] +to_bootstrap = locals()["to_bootstrap"] def load_file_as_module(name): path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name) if sys.version_info >= (3, 3): from importlib import machinery + mod = machinery.SourceFileLoader(name, path).load_module() else: import imp + mod = imp.load_source(name, path) return mod + if to_bootstrap == "pytest": sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin") diff --git a/lib/sqlalchemy/testing/plugin/noseplugin.py b/lib/sqlalchemy/testing/plugin/noseplugin.py index 20ea61d89..0c28a5213 100644 --- a/lib/sqlalchemy/testing/plugin/noseplugin.py +++ b/lib/sqlalchemy/testing/plugin/noseplugin.py @@ -25,6 +25,7 @@ import sys from nose.plugins import Plugin import nose + fixtures = None py3k = sys.version_info >= (3, 0) @@ -33,7 +34,7 @@ py3k = sys.version_info >= (3, 0) class NoseSQLAlchemy(Plugin): enabled = True - name = 'sqla_testing' + name = "sqla_testing" score = 100 def options(self, parser, env=os.environ): @@ -41,10 +42,14 @@ class NoseSQLAlchemy(Plugin): opt = parser.add_option def make_option(name, **kw): - callback_ = kw.pop("callback", None) or kw.pop("zeroarg_callback", None) + callback_ = kw.pop("callback", None) or kw.pop( + "zeroarg_callback", None + ) if callback_: + def wrap_(option, opt_str, value, parser): callback_(opt_str, value, parser) + kw["callback"] = wrap_ opt(name, **kw) @@ -73,7 +78,7 @@ class NoseSQLAlchemy(Plugin): def wantMethod(self, fn): if py3k: - if not hasattr(fn.__self__, 'cls'): + if not hasattr(fn.__self__, "cls"): return False cls = fn.__self__.cls else: @@ -84,24 +89,24 @@ class NoseSQLAlchemy(Plugin): return plugin_base.want_class(cls) def beforeTest(self, test): - if not hasattr(test.test, 'cls'): + if not hasattr(test.test, "cls"): return plugin_base.before_test( test, test.test.cls.__module__, - test.test.cls, test.test.method.__name__) + test.test.cls, + test.test.method.__name__, + ) def afterTest(self, test): plugin_base.after_test(test) def startContext(self, ctx): - if not isinstance(ctx, type) \ - or not issubclass(ctx, fixtures.TestBase): + if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase): return plugin_base.start_test_class(ctx) def stopContext(self, ctx): - if not isinstance(ctx, type) \ - or not issubclass(ctx, fixtures.TestBase): + if not isinstance(ctx, type) or not issubclass(ctx, fixtures.TestBase): return plugin_base.stop_test_class(ctx) diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 0ffcae093..5d6bf2975 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -46,58 +46,130 @@ options = None def setup_options(make_option): - make_option("--log-info", action="callback", type="string", callback=_log, - help="turn on info logging for <LOG> (multiple OK)") - make_option("--log-debug", action="callback", - type="string", callback=_log, - help="turn on debug logging for <LOG> (multiple OK)") - make_option("--db", action="append", type="string", dest="db", - help="Use prefab database uri. Multiple OK, " - "first one is run by default.") - make_option('--dbs', action='callback', zeroarg_callback=_list_dbs, - help="List available prefab dbs") - make_option("--dburi", action="append", type="string", dest="dburi", - help="Database uri. Multiple OK, " - "first one is run by default.") - make_option("--dropfirst", action="store_true", dest="dropfirst", - help="Drop all tables in the target database first") - make_option("--backend-only", action="store_true", dest="backend_only", - help="Run only tests marked with __backend__") - make_option("--nomemory", action="store_true", dest="nomemory", - help="Don't run memory profiling tests") - make_option("--postgresql-templatedb", type="string", - help="name of template database to use for PostgreSQL " - "CREATE DATABASE (defaults to current database)") - make_option("--low-connections", action="store_true", - dest="low_connections", - help="Use a low number of distinct connections - " - "i.e. for Oracle TNS") - make_option("--write-idents", type="string", dest="write_idents", - help="write out generated follower idents to <file>, " - "when -n<num> is used") - make_option("--reversetop", action="store_true", - dest="reversetop", default=False, - help="Use a random-ordering set implementation in the ORM " - "(helps reveal dependency issues)") - make_option("--requirements", action="callback", type="string", - callback=_requirements_opt, - help="requirements class for testing, overrides setup.cfg") - make_option("--with-cdecimal", action="store_true", - dest="cdecimal", default=False, - help="Monkeypatch the cdecimal library into Python 'decimal' " - "for all tests") - make_option("--include-tag", action="callback", callback=_include_tag, - type="string", - help="Include tests with tag <tag>") - make_option("--exclude-tag", action="callback", callback=_exclude_tag, - type="string", - help="Exclude tests with tag <tag>") - make_option("--write-profiles", action="store_true", - dest="write_profiles", default=False, - help="Write/update failing profiling data.") - make_option("--force-write-profiles", action="store_true", - dest="force_write_profiles", default=False, - help="Unconditionally write/update profiling data.") + make_option( + "--log-info", + action="callback", + type="string", + callback=_log, + help="turn on info logging for <LOG> (multiple OK)", + ) + make_option( + "--log-debug", + action="callback", + type="string", + callback=_log, + help="turn on debug logging for <LOG> (multiple OK)", + ) + make_option( + "--db", + action="append", + type="string", + dest="db", + help="Use prefab database uri. Multiple OK, " + "first one is run by default.", + ) + make_option( + "--dbs", + action="callback", + zeroarg_callback=_list_dbs, + help="List available prefab dbs", + ) + make_option( + "--dburi", + action="append", + type="string", + dest="dburi", + help="Database uri. Multiple OK, " "first one is run by default.", + ) + make_option( + "--dropfirst", + action="store_true", + dest="dropfirst", + help="Drop all tables in the target database first", + ) + make_option( + "--backend-only", + action="store_true", + dest="backend_only", + help="Run only tests marked with __backend__", + ) + make_option( + "--nomemory", + action="store_true", + dest="nomemory", + help="Don't run memory profiling tests", + ) + make_option( + "--postgresql-templatedb", + type="string", + help="name of template database to use for PostgreSQL " + "CREATE DATABASE (defaults to current database)", + ) + make_option( + "--low-connections", + action="store_true", + dest="low_connections", + help="Use a low number of distinct connections - " + "i.e. for Oracle TNS", + ) + make_option( + "--write-idents", + type="string", + dest="write_idents", + help="write out generated follower idents to <file>, " + "when -n<num> is used", + ) + make_option( + "--reversetop", + action="store_true", + dest="reversetop", + default=False, + help="Use a random-ordering set implementation in the ORM " + "(helps reveal dependency issues)", + ) + make_option( + "--requirements", + action="callback", + type="string", + callback=_requirements_opt, + help="requirements class for testing, overrides setup.cfg", + ) + make_option( + "--with-cdecimal", + action="store_true", + dest="cdecimal", + default=False, + help="Monkeypatch the cdecimal library into Python 'decimal' " + "for all tests", + ) + make_option( + "--include-tag", + action="callback", + callback=_include_tag, + type="string", + help="Include tests with tag <tag>", + ) + make_option( + "--exclude-tag", + action="callback", + callback=_exclude_tag, + type="string", + help="Exclude tests with tag <tag>", + ) + make_option( + "--write-profiles", + action="store_true", + dest="write_profiles", + default=False, + help="Write/update failing profiling data.", + ) + make_option( + "--force-write-profiles", + action="store_true", + dest="force_write_profiles", + default=False, + help="Unconditionally write/update profiling data.", + ) def configure_follower(follower_ident): @@ -108,6 +180,7 @@ def configure_follower(follower_ident): """ from sqlalchemy.testing import provision + provision.FOLLOWER_IDENT = follower_ident @@ -121,9 +194,9 @@ def memoize_important_follower_config(dict_): callables, so we have to just copy all of that over. """ - dict_['memoized_config'] = { - 'include_tags': include_tags, - 'exclude_tags': exclude_tags + dict_["memoized_config"] = { + "include_tags": include_tags, + "exclude_tags": exclude_tags, } @@ -134,14 +207,14 @@ def restore_important_follower_config(dict_): """ global include_tags, exclude_tags - include_tags.update(dict_['memoized_config']['include_tags']) - exclude_tags.update(dict_['memoized_config']['exclude_tags']) + include_tags.update(dict_["memoized_config"]["include_tags"]) + exclude_tags.update(dict_["memoized_config"]["exclude_tags"]) def read_config(): global file_config file_config = configparser.ConfigParser() - file_config.read(['setup.cfg', 'test.cfg']) + file_config.read(["setup.cfg", "test.cfg"]) def pre_begin(opt): @@ -155,6 +228,7 @@ def pre_begin(opt): def set_coverage_flag(value): options.has_coverage = value + _skip_test_exception = None @@ -171,34 +245,33 @@ def post_begin(): # late imports, has to happen after config as well # as nose plugins like coverage - global util, fixtures, engines, exclusions, \ - assertions, warnings, profiling,\ - config, testing - from sqlalchemy import testing # noqa + global util, fixtures, engines, exclusions, assertions, warnings, profiling, config, testing + from sqlalchemy import testing # noqa from sqlalchemy.testing import fixtures, engines, exclusions # noqa - from sqlalchemy.testing import assertions, warnings, profiling # noqa + from sqlalchemy.testing import assertions, warnings, profiling # noqa from sqlalchemy.testing import config # noqa from sqlalchemy import util # noqa - warnings.setup_filters() + warnings.setup_filters() def _log(opt_str, value, parser): global logging if not logging: import logging + logging.basicConfig() - if opt_str.endswith('-info'): + if opt_str.endswith("-info"): logging.getLogger(value).setLevel(logging.INFO) - elif opt_str.endswith('-debug'): + elif opt_str.endswith("-debug"): logging.getLogger(value).setLevel(logging.DEBUG) def _list_dbs(*args): print("Available --db options (use --dburi to override)") - for macro in sorted(file_config.options('db')): - print("%20s\t%s" % (macro, file_config.get('db', macro))) + for macro in sorted(file_config.options("db")): + print("%20s\t%s" % (macro, file_config.get("db", macro))) sys.exit(0) @@ -207,11 +280,12 @@ def _requirements_opt(opt_str, value, parser): def _exclude_tag(opt_str, value, parser): - exclude_tags.add(value.replace('-', '_')) + exclude_tags.add(value.replace("-", "_")) def _include_tag(opt_str, value, parser): - include_tags.add(value.replace('-', '_')) + include_tags.add(value.replace("-", "_")) + pre_configure = [] post_configure = [] @@ -243,7 +317,8 @@ def _set_nomemory(opt, file_config): def _monkeypatch_cdecimal(options, file_config): if options.cdecimal: import cdecimal - sys.modules['decimal'] = cdecimal + + sys.modules["decimal"] = cdecimal @post @@ -266,27 +341,28 @@ def _engine_uri(options, file_config): if options.db: for db_token in options.db: - for db in re.split(r'[,\s]+', db_token): - if db not in file_config.options('db'): + for db in re.split(r"[,\s]+", db_token): + if db not in file_config.options("db"): raise RuntimeError( "Unknown URI specifier '%s'. " - "Specify --dbs for known uris." - % db) + "Specify --dbs for known uris." % db + ) else: - db_urls.append(file_config.get('db', db)) + db_urls.append(file_config.get("db", db)) if not db_urls: - db_urls.append(file_config.get('db', 'default')) + db_urls.append(file_config.get("db", "default")) config._current = None for db_url in db_urls: - if options.write_idents and provision.FOLLOWER_IDENT: # != 'master': + if options.write_idents and provision.FOLLOWER_IDENT: # != 'master': with open(options.write_idents, "a") as file_: file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n") cfg = provision.setup_config( - db_url, options, file_config, provision.FOLLOWER_IDENT) + db_url, options, file_config, provision.FOLLOWER_IDENT + ) if not config._current: cfg.set_as_current(cfg, testing) @@ -295,7 +371,7 @@ def _engine_uri(options, file_config): @post def _requirements(options, file_config): - requirement_cls = file_config.get('sqla_testing', "requirement_cls") + requirement_cls = file_config.get("sqla_testing", "requirement_cls") _setup_requirements(requirement_cls) @@ -334,22 +410,28 @@ def _prep_testing_database(options, file_config): pass else: for vname in view_names: - e.execute(schema._DropView( - schema.Table(vname, schema.MetaData()) - )) + e.execute( + schema._DropView( + schema.Table(vname, schema.MetaData()) + ) + ) if config.requirements.schemas.enabled_for_config(cfg): try: - view_names = inspector.get_view_names( - schema="test_schema") + view_names = inspector.get_view_names(schema="test_schema") except NotImplementedError: pass else: for vname in view_names: - e.execute(schema._DropView( - schema.Table(vname, schema.MetaData(), - schema="test_schema") - )) + e.execute( + schema._DropView( + schema.Table( + vname, + schema.MetaData(), + schema="test_schema", + ) + ) + ) util.drop_all_tables(e, inspector) @@ -358,23 +440,29 @@ def _prep_testing_database(options, file_config): if against(cfg, "postgresql"): from sqlalchemy.dialects import postgresql + for enum in inspector.get_enums("*"): - e.execute(postgresql.DropEnumType( - postgresql.ENUM( - name=enum['name'], - schema=enum['schema']))) + e.execute( + postgresql.DropEnumType( + postgresql.ENUM( + name=enum["name"], schema=enum["schema"] + ) + ) + ) @post def _reverse_topological(options, file_config): if options.reversetop: from sqlalchemy.orm.util import randomize_unitofwork + randomize_unitofwork() @post def _post_setup_options(opt, file_config): from sqlalchemy.testing import config + config.options = options config.file_config = file_config @@ -382,17 +470,20 @@ def _post_setup_options(opt, file_config): @post def _setup_profiling(options, file_config): from sqlalchemy.testing import profiling + profiling._profile_stats = profiling.ProfileStatsFile( - file_config.get('sqla_testing', 'profile_file')) + file_config.get("sqla_testing", "profile_file") + ) def want_class(cls): if not issubclass(cls, fixtures.TestBase): return False - elif cls.__name__.startswith('_'): + elif cls.__name__.startswith("_"): return False - elif config.options.backend_only and not getattr(cls, '__backend__', - False): + elif config.options.backend_only and not getattr( + cls, "__backend__", False + ): return False else: return True @@ -405,25 +496,28 @@ def want_method(cls, fn): return False elif include_tags: return ( - hasattr(cls, '__tags__') and - exclusions.tags(cls.__tags__).include_test( - include_tags, exclude_tags) + hasattr(cls, "__tags__") + and exclusions.tags(cls.__tags__).include_test( + include_tags, exclude_tags + ) ) or ( - hasattr(fn, '_sa_exclusion_extend') and - fn._sa_exclusion_extend.include_test( - include_tags, exclude_tags) + hasattr(fn, "_sa_exclusion_extend") + and fn._sa_exclusion_extend.include_test( + include_tags, exclude_tags + ) ) - elif exclude_tags and hasattr(cls, '__tags__'): + elif exclude_tags and hasattr(cls, "__tags__"): return exclusions.tags(cls.__tags__).include_test( - include_tags, exclude_tags) - elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'): + include_tags, exclude_tags + ) + elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"): return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags) else: return True def generate_sub_tests(cls, module): - if getattr(cls, '__backend__', False): + if getattr(cls, "__backend__", False): for cfg in _possible_configs_for_cls(cls): orig_name = cls.__name__ @@ -431,16 +525,13 @@ def generate_sub_tests(cls, module): # pytest junit plugin, which is tripped up by the brackets # and periods, so sanitize - alpha_name = re.sub('[_\[\]\.]+', '_', cfg.name) - alpha_name = re.sub('_+$', '', alpha_name) + alpha_name = re.sub("[_\[\]\.]+", "_", cfg.name) + alpha_name = re.sub("_+$", "", alpha_name) name = "%s_%s" % (cls.__name__, alpha_name) subcls = type( name, - (cls, ), - { - "_sa_orig_cls_name": orig_name, - "__only_on_config__": cfg - } + (cls,), + {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg}, ) setattr(module, name, subcls) yield subcls @@ -454,8 +545,8 @@ def start_test_class(cls): def stop_test_class(cls): - #from sqlalchemy import inspect - #assert not inspect(testing.db).get_table_names() + # from sqlalchemy import inspect + # assert not inspect(testing.db).get_table_names() engines.testing_reaper._stop_test_ctx() try: if not options.low_connections: @@ -475,7 +566,7 @@ def final_process_cleanup(): def _setup_engine(cls): - if getattr(cls, '__engine_options__', None): + if getattr(cls, "__engine_options__", None): eng = engines.testing_engine(options=cls.__engine_options__) config._current.push_engine(eng, testing) @@ -485,7 +576,7 @@ def before_test(test, test_module_name, test_class, test_name): # like a nose id, e.g.: # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause" - name = getattr(test_class, '_sa_orig_cls_name', test_class.__name__) + name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__) id_ = "%s.%s.%s" % (test_module_name, name, test_name) @@ -505,16 +596,16 @@ def _possible_configs_for_cls(cls, reasons=None): if spec(config_obj): all_configs.remove(config_obj) - if getattr(cls, '__only_on__', None): + if getattr(cls, "__only_on__", None): spec = exclusions.db_spec(*util.to_list(cls.__only_on__)) for config_obj in list(all_configs): if not spec(config_obj): all_configs.remove(config_obj) - if getattr(cls, '__only_on_config__', None): + if getattr(cls, "__only_on_config__", None): all_configs.intersection_update([cls.__only_on_config__]) - if hasattr(cls, '__requires__'): + if hasattr(cls, "__requires__"): requirements = config.requirements for config_obj in list(all_configs): for requirement in cls.__requires__: @@ -527,7 +618,7 @@ def _possible_configs_for_cls(cls, reasons=None): reasons.extend(skip_reasons) break - if hasattr(cls, '__prefer_requires__'): + if hasattr(cls, "__prefer_requires__"): non_preferred = set() requirements = config.requirements for config_obj in list(all_configs): @@ -546,30 +637,32 @@ def _do_skips(cls): reasons = [] all_configs = _possible_configs_for_cls(cls, reasons) - if getattr(cls, '__skip_if__', False): - for c in getattr(cls, '__skip_if__'): + if getattr(cls, "__skip_if__", False): + for c in getattr(cls, "__skip_if__"): if c(): - config.skip_test("'%s' skipped by %s" % ( - cls.__name__, c.__name__) + config.skip_test( + "'%s' skipped by %s" % (cls.__name__, c.__name__) ) if not all_configs: msg = "'%s' unsupported on any DB implementation %s%s" % ( cls.__name__, ", ".join( - "'%s(%s)+%s'" % ( + "'%s(%s)+%s'" + % ( config_obj.db.name, ".".join( - str(dig) for dig in - exclusions._server_version(config_obj.db)), - config_obj.db.driver + str(dig) + for dig in exclusions._server_version(config_obj.db) + ), + config_obj.db.driver, ) - for config_obj in config.Config.all_configs() + for config_obj in config.Config.all_configs() ), - ", ".join(reasons) + ", ".join(reasons), ) config.skip_test(msg) - elif hasattr(cls, '__prefer_backends__'): + elif hasattr(cls, "__prefer_backends__"): non_preferred = set() spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__)) for config_obj in all_configs: diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index da682ea00..fd0a48462 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -13,6 +13,7 @@ import os try: import xdist # noqa + has_xdist = True except ImportError: has_xdist = False @@ -24,30 +25,42 @@ def pytest_addoption(parser): def make_option(name, **kw): callback_ = kw.pop("callback", None) if callback_: + class CallableAction(argparse.Action): - def __call__(self, parser, namespace, - values, option_string=None): + def __call__( + self, parser, namespace, values, option_string=None + ): callback_(option_string, values, parser) + kw["action"] = CallableAction zeroarg_callback = kw.pop("zeroarg_callback", None) if zeroarg_callback: + class CallableAction(argparse.Action): - def __init__(self, option_strings, - dest, default=False, - required=False, help=None): - super(CallableAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=0, - const=True, - default=default, - required=required, - help=help) - - def __call__(self, parser, namespace, - values, option_string=None): + def __init__( + self, + option_strings, + dest, + default=False, + required=False, + help=None, + ): + super(CallableAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=True, + default=default, + required=required, + help=help, + ) + + def __call__( + self, parser, namespace, values, option_string=None + ): zeroarg_callback(option_string, values, parser) + kw["action"] = CallableAction group.addoption(name, **kw) @@ -59,18 +72,18 @@ def pytest_addoption(parser): def pytest_configure(config): if hasattr(config, "slaveinput"): plugin_base.restore_important_follower_config(config.slaveinput) - plugin_base.configure_follower( - config.slaveinput["follower_ident"] - ) + plugin_base.configure_follower(config.slaveinput["follower_ident"]) else: - if config.option.write_idents and \ - os.path.exists(config.option.write_idents): + if config.option.write_idents and os.path.exists( + config.option.write_idents + ): os.remove(config.option.write_idents) plugin_base.pre_begin(config.option) - plugin_base.set_coverage_flag(bool(getattr(config.option, - "cov_source", False))) + plugin_base.set_coverage_flag( + bool(getattr(config.option, "cov_source", False)) + ) plugin_base.set_skip_test(pytest.skip.Exception) @@ -94,10 +107,12 @@ if has_xdist: node.slaveinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12] from sqlalchemy.testing import provision + provision.create_follower_db(node.slaveinput["follower_ident"]) def pytest_testnodedown(node, error): from sqlalchemy.testing import provision + provision.drop_follower_db(node.slaveinput["follower_ident"]) @@ -114,19 +129,22 @@ def pytest_collection_modifyitems(session, config, items): rebuilt_items = collections.defaultdict(list) items[:] = [ - item for item in - items if isinstance(item.parent, pytest.Instance) - and not item.parent.parent.name.startswith("_")] + item + for item in items + if isinstance(item.parent, pytest.Instance) + and not item.parent.parent.name.startswith("_") + ] test_classes = set(item.parent for item in items) for test_class in test_classes: for sub_cls in plugin_base.generate_sub_tests( - test_class.cls, test_class.parent.module): + test_class.cls, test_class.parent.module + ): if sub_cls is not test_class.cls: list_ = rebuilt_items[test_class.cls] for inst in pytest.Class( - sub_cls.__name__, - parent=test_class.parent.parent).collect(): + sub_cls.__name__, parent=test_class.parent.parent + ).collect(): list_.extend(inst.collect()) newitems = [] @@ -139,23 +157,29 @@ def pytest_collection_modifyitems(session, config, items): # seems like the functions attached to a test class aren't sorted already? # is that true and why's that? (when using unittest, they're sorted) - items[:] = sorted(newitems, key=lambda item: ( - item.parent.parent.parent.name, - item.parent.parent.name, - item.name - )) + items[:] = sorted( + newitems, + key=lambda item: ( + item.parent.parent.parent.name, + item.parent.parent.name, + item.name, + ), + ) def pytest_pycollect_makeitem(collector, name, obj): if inspect.isclass(obj) and plugin_base.want_class(obj): return pytest.Class(name, parent=collector) - elif inspect.isfunction(obj) and \ - isinstance(collector, pytest.Instance) and \ - plugin_base.want_method(collector.cls, obj): + elif ( + inspect.isfunction(obj) + and isinstance(collector, pytest.Instance) + and plugin_base.want_method(collector.cls, obj) + ): return pytest.Function(name, parent=collector) else: return [] + _current_class = None @@ -180,6 +204,7 @@ def pytest_runtest_setup(item): global _current_class class_teardown(item.parent.parent) _current_class = None + item.parent.parent.addfinalizer(finalize) test_setup(item) @@ -194,8 +219,9 @@ def pytest_runtest_teardown(item): def test_setup(item): - plugin_base.before_test(item, item.parent.module.__name__, - item.parent.cls, item.name) + plugin_base.before_test( + item, item.parent.module.__name__, item.parent.cls, item.name + ) def test_teardown(item): diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index fab99b186..3986985c7 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -42,17 +42,16 @@ class ProfileStatsFile(object): def __init__(self, filename): self.force_write = ( - config.options is not None and - config.options.force_write_profiles + config.options is not None and config.options.force_write_profiles ) self.write = self.force_write or ( - config.options is not None and - config.options.write_profiles + config.options is not None and config.options.write_profiles ) self.fname = os.path.abspath(filename) self.short_fname = os.path.split(self.fname)[-1] self.data = collections.defaultdict( - lambda: collections.defaultdict(dict)) + lambda: collections.defaultdict(dict) + ) self._read() if self.write: # rewrite for the case where features changed, @@ -65,7 +64,7 @@ class ProfileStatsFile(object): dbapi_key = config.db.name + "_" + config.db.driver # keep it at 2.7, 3.1, 3.2, etc. for now. - py_version = '.'.join([str(v) for v in sys.version_info[0:2]]) + py_version = ".".join([str(v) for v in sys.version_info[0:2]]) platform_tokens = [py_version] platform_tokens.append(dbapi_key) @@ -87,8 +86,7 @@ class ProfileStatsFile(object): def has_stats(self): test_key = _current_test return ( - test_key in self.data and - self.platform_key in self.data[test_key] + test_key in self.data and self.platform_key in self.data[test_key] ) def result(self, callcount): @@ -96,15 +94,15 @@ class ProfileStatsFile(object): per_fn = self.data[test_key] per_platform = per_fn[self.platform_key] - if 'counts' not in per_platform: - per_platform['counts'] = counts = [] + if "counts" not in per_platform: + per_platform["counts"] = counts = [] else: - counts = per_platform['counts'] + counts = per_platform["counts"] - if 'current_count' not in per_platform: - per_platform['current_count'] = current_count = 0 + if "current_count" not in per_platform: + per_platform["current_count"] = current_count = 0 else: - current_count = per_platform['current_count'] + current_count = per_platform["current_count"] has_count = len(counts) > current_count @@ -114,16 +112,16 @@ class ProfileStatsFile(object): self._write() result = None else: - result = per_platform['lineno'], counts[current_count] - per_platform['current_count'] += 1 + result = per_platform["lineno"], counts[current_count] + per_platform["current_count"] += 1 return result def replace(self, callcount): test_key = _current_test per_fn = self.data[test_key] per_platform = per_fn[self.platform_key] - counts = per_platform['counts'] - current_count = per_platform['current_count'] + counts = per_platform["counts"] + current_count = per_platform["current_count"] if current_count < len(counts): counts[current_count - 1] = callcount else: @@ -164,9 +162,9 @@ class ProfileStatsFile(object): per_fn = self.data[test_key] per_platform = per_fn[platform_key] c = [int(count) for count in counts.split(",")] - per_platform['counts'] = c - per_platform['lineno'] = lineno + 1 - per_platform['current_count'] = 0 + per_platform["counts"] = c + per_platform["lineno"] = lineno + 1 + per_platform["current_count"] = 0 profile_f.close() def _write(self): @@ -179,7 +177,7 @@ class ProfileStatsFile(object): profile_f.write("\n# TEST: %s\n\n" % test_key) for platform_key in sorted(per_fn): per_platform = per_fn[platform_key] - c = ",".join(str(count) for count in per_platform['counts']) + c = ",".join(str(count) for count in per_platform["counts"]) profile_f.write("%s %s %s\n" % (test_key, platform_key, c)) profile_f.close() @@ -199,7 +197,9 @@ def function_call_count(variance=0.05): def wrap(*args, **kw): with count_functions(variance=variance): return fn(*args, **kw) + return update_wrapper(wrap, fn) + return decorate @@ -213,21 +213,22 @@ def count_functions(variance=0.05): "No profiling stats available on this " "platform for this function. Run tests with " "--write-profiles to add statistics to %s for " - "this platform." % _profile_stats.short_fname) + "this platform." % _profile_stats.short_fname + ) gc_collect() pr = cProfile.Profile() pr.enable() - #began = time.time() + # began = time.time() yield - #ended = time.time() + # ended = time.time() pr.disable() - #s = compat.StringIO() + # s = compat.StringIO() stats = pstats.Stats(pr, stream=sys.stdout) - #timespent = ended - began + # timespent = ended - began callcount = stats.total_calls expected = _profile_stats.result(callcount) @@ -237,11 +238,7 @@ def count_functions(variance=0.05): else: line_no, expected_count = expected - print(("Pstats calls: %d Expected %s" % ( - callcount, - expected_count - ) - )) + print(("Pstats calls: %d Expected %s" % (callcount, expected_count))) stats.sort_stats("cumulative") stats.print_stats() @@ -259,7 +256,9 @@ def count_functions(variance=0.05): "--write-profiles to " "regenerate this callcount." % ( - callcount, (variance * 100), - expected_count, _profile_stats.platform_key)) - - + callcount, + (variance * 100), + expected_count, + _profile_stats.platform_key, + ) + ) diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index c0ca7c1cb..25028ccb3 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -8,6 +8,7 @@ import collections import os import time import logging + log = logging.getLogger(__name__) FOLLOWER_IDENT = None @@ -25,6 +26,7 @@ class register(object): def decorate(fn): self.fns[dbname] = fn return self + return decorate def __call__(self, cfg, *arg): @@ -38,7 +40,7 @@ class register(object): if backend in self.fns: return self.fns[backend](cfg, *arg) else: - return self.fns['*'](cfg, *arg) + return self.fns["*"](cfg, *arg) def create_follower_db(follower_ident): @@ -82,9 +84,7 @@ def _configs_for_db_operation(): for cfg in config.Config.all_configs(): url = cfg.db.url backend = url.get_backend_name() - host_conf = ( - backend, - url.username, url.host, url.database) + host_conf = (backend, url.username, url.host, url.database) if host_conf not in hosts: yield cfg @@ -128,14 +128,13 @@ def _follower_url_from_main(url, ident): @_update_db_opts.for_db("mssql") def _mssql_update_db_opts(db_url, db_opts): - db_opts['legacy_schema_aliasing'] = False - + db_opts["legacy_schema_aliasing"] = False @_follower_url_from_main.for_db("sqlite") def _sqlite_follower_url_from_main(url, ident): url = sa_url.make_url(url) - if not url.database or url.database == ':memory:': + if not url.database or url.database == ":memory:": return url else: return sa_url.make_url("sqlite:///%s.db" % ident) @@ -151,19 +150,20 @@ def _sqlite_post_configure_engine(url, engine, follower_ident): # as an attached if not follower_ident: dbapi_connection.execute( - 'ATTACH DATABASE "test_schema.db" AS test_schema') + 'ATTACH DATABASE "test_schema.db" AS test_schema' + ) else: dbapi_connection.execute( 'ATTACH DATABASE "%s_test_schema.db" AS test_schema' - % follower_ident) + % follower_ident + ) @_create_db.for_db("postgresql") def _pg_create_db(cfg, eng, ident): template_db = cfg.options.postgresql_templatedb - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: try: _pg_drop_db(cfg, conn, ident) except Exception: @@ -175,7 +175,8 @@ def _pg_create_db(cfg, eng, ident): while True: try: conn.execute( - "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)) + "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db) + ) except exc.OperationalError as err: attempt += 1 if attempt >= 3: @@ -184,8 +185,11 @@ def _pg_create_db(cfg, eng, ident): log.info( "Waiting to create %s, URI %r, " "template DB %s is in use sleeping for .5", - ident, eng.url, template_db) - time.sleep(.5) + ident, + eng.url, + template_db, + ) + time.sleep(0.5) else: break @@ -203,9 +207,11 @@ def _mysql_create_db(cfg, eng, ident): # 1271, u"Illegal mix of collations for operation 'UNION'" conn.execute("CREATE DATABASE %s CHARACTER SET utf8mb3" % ident) conn.execute( - "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident) + "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb3" % ident + ) conn.execute( - "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident) + "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb3" % ident + ) @_configure_follower.for_db("mysql") @@ -221,14 +227,15 @@ def _sqlite_create_db(cfg, eng, ident): @_drop_db.for_db("postgresql") def _pg_drop_db(cfg, eng, ident): - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: conn.execute( text( "select pg_terminate_backend(pid) from pg_stat_activity " "where usename=current_user and pid != pg_backend_pid() " "and datname=:dname" - ), dname=ident) + ), + dname=ident, + ) conn.execute("DROP DATABASE %s" % ident) @@ -257,11 +264,12 @@ def _oracle_create_db(cfg, eng, ident): conn.execute("create user %s identified by xe" % ident) conn.execute("create user %s_ts1 identified by xe" % ident) conn.execute("create user %s_ts2 identified by xe" % ident) - conn.execute("grant dba to %s" % (ident, )) + conn.execute("grant dba to %s" % (ident,)) conn.execute("grant unlimited tablespace to %s" % ident) conn.execute("grant unlimited tablespace to %s_ts1" % ident) conn.execute("grant unlimited tablespace to %s_ts2" % ident) + @_configure_follower.for_db("oracle") def _oracle_configure_follower(config, ident): config.test_schema = "%s_ts1" % ident @@ -320,6 +328,7 @@ def reap_dbs(idents_file): elif backend == "mssql": _reap_mssql_dbs(url, ident) + def _reap_oracle_dbs(url, idents): log.info("db reaper connecting to %r", url) eng = create_engine(url) @@ -330,8 +339,9 @@ def _reap_oracle_dbs(url, idents): to_reap = conn.execute( "select u.username from all_users u where username " "like 'TEST_%' and not exists (select username " - "from v$session where username=u.username)") - all_names = {username.lower() for (username, ) in to_reap} + "from v$session where username=u.username)" + ) + all_names = {username.lower() for (username,) in to_reap} to_drop = set() for name in all_names: if name.endswith("_ts1") or name.endswith("_ts2"): @@ -348,28 +358,28 @@ def _reap_oracle_dbs(url, idents): if _ora_drop_ignore(conn, username): dropped += 1 log.info( - "Dropped %d out of %d stale databases detected", - dropped, total) - + "Dropped %d out of %d stale databases detected", dropped, total + ) @_follower_url_from_main.for_db("oracle") def _oracle_follower_url_from_main(url, ident): url = sa_url.make_url(url) url.username = ident - url.password = 'xe' + url.password = "xe" return url @_create_db.for_db("mssql") def _mssql_create_db(cfg, eng, ident): - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: conn.execute("create database %s" % ident) conn.execute( - "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident) + "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident + ) conn.execute( - "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident) + "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident + ) conn.execute("use %s" % ident) conn.execute("create schema test_schema") conn.execute("create schema test_schema_2") @@ -377,10 +387,10 @@ def _mssql_create_db(cfg, eng, ident): @_drop_db.for_db("mssql") def _mssql_drop_db(cfg, eng, ident): - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: _mssql_drop_ignore(conn, ident) + def _mssql_drop_ignore(conn, ident): try: # typically when this happens, we can't KILL the session anyway, @@ -401,8 +411,7 @@ def _mssql_drop_ignore(conn, ident): def _reap_mssql_dbs(url, idents): log.info("db reaper connecting to %r", url) eng = create_engine(url) - with eng.connect().execution_options( - isolation_level="AUTOCOMMIT") as conn: + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: log.info("identifiers in file: %s", ", ".join(idents)) @@ -410,8 +419,9 @@ def _reap_mssql_dbs(url, idents): "select d.name from sys.databases as d where name " "like 'TEST_%' and not exists (select session_id " "from sys.dm_exec_sessions " - "where database_id=d.database_id)") - all_names = {dbname.lower() for (dbname, ) in to_reap} + "where database_id=d.database_id)" + ) + all_names = {dbname.lower() for (dbname,) in to_reap} to_drop = set() for name in all_names: if name in idents: @@ -422,5 +432,5 @@ def _reap_mssql_dbs(url, idents): if _mssql_drop_ignore(conn, dbname): dropped += 1 log.info( - "Dropped %d out of %d stale databases detected", - dropped, total) + "Dropped %d out of %d stale databases detected", dropped, total + ) diff --git a/lib/sqlalchemy/testing/replay_fixture.py b/lib/sqlalchemy/testing/replay_fixture.py index b50f52e3d..9832b07a2 100644 --- a/lib/sqlalchemy/testing/replay_fixture.py +++ b/lib/sqlalchemy/testing/replay_fixture.py @@ -11,7 +11,6 @@ from sqlalchemy.orm import Session class ReplayFixtureTest(fixtures.TestBase): - @contextlib.contextmanager def _dummy_ctx(self, *arg, **kw): yield @@ -22,8 +21,8 @@ class ReplayFixtureTest(fixtures.TestBase): creator = config.db.pool._creator recorder = lambda: dbapi_session.recorder(creator()) engine = create_engine( - config.db.url, creator=recorder, - use_native_hstore=False) + config.db.url, creator=recorder, use_native_hstore=False + ) self.metadata = MetaData(engine) self.engine = engine self.session = Session(engine) @@ -37,8 +36,8 @@ class ReplayFixtureTest(fixtures.TestBase): player = lambda: dbapi_session.player() engine = create_engine( - config.db.url, creator=player, - use_native_hstore=False) + config.db.url, creator=player, use_native_hstore=False + ) self.metadata = MetaData(engine) self.engine = engine @@ -74,21 +73,49 @@ class ReplayableSession(object): NoAttribute = object() if util.py2k: - Natives = set([getattr(types, t) - for t in dir(types) if not t.startswith('_')]).\ - difference([getattr(types, t) - for t in ('FunctionType', 'BuiltinFunctionType', - 'MethodType', 'BuiltinMethodType', - 'LambdaType', 'UnboundMethodType',)]) + Natives = set( + [getattr(types, t) for t in dir(types) if not t.startswith("_")] + ).difference( + [ + getattr(types, t) + for t in ( + "FunctionType", + "BuiltinFunctionType", + "MethodType", + "BuiltinMethodType", + "LambdaType", + "UnboundMethodType", + ) + ] + ) else: - Natives = set([getattr(types, t) - for t in dir(types) if not t.startswith('_')]).\ - union([type(t) if not isinstance(t, type) - else t for t in __builtins__.values()]).\ - difference([getattr(types, t) - for t in ('FunctionType', 'BuiltinFunctionType', - 'MethodType', 'BuiltinMethodType', - 'LambdaType', )]) + Natives = ( + set( + [ + getattr(types, t) + for t in dir(types) + if not t.startswith("_") + ] + ) + .union( + [ + type(t) if not isinstance(t, type) else t + for t in __builtins__.values() + ] + ) + .difference( + [ + getattr(types, t) + for t in ( + "FunctionType", + "BuiltinFunctionType", + "MethodType", + "BuiltinMethodType", + "LambdaType", + ) + ] + ) + ) def __init__(self): self.buffer = deque() @@ -105,8 +132,10 @@ class ReplayableSession(object): self._subject = subject def __call__(self, *args, **kw): - subject, buffer = [object.__getattribute__(self, x) - for x in ('_subject', '_buffer')] + subject, buffer = [ + object.__getattribute__(self, x) + for x in ("_subject", "_buffer") + ] result = subject(*args, **kw) if type(result) not in ReplayableSession.Natives: @@ -126,8 +155,10 @@ class ReplayableSession(object): except AttributeError: pass - subject, buffer = [object.__getattribute__(self, x) - for x in ('_subject', '_buffer')] + subject, buffer = [ + object.__getattribute__(self, x) + for x in ("_subject", "_buffer") + ] try: result = type(subject).__getattribute__(subject, key) except AttributeError: @@ -146,7 +177,7 @@ class ReplayableSession(object): self._buffer = buffer def __call__(self, *args, **kw): - buffer = object.__getattribute__(self, '_buffer') + buffer = object.__getattribute__(self, "_buffer") result = buffer.popleft() if result is ReplayableSession.Callable: return self @@ -162,7 +193,7 @@ class ReplayableSession(object): return object.__getattribute__(self, key) except AttributeError: pass - buffer = object.__getattribute__(self, '_buffer') + buffer = object.__getattribute__(self, "_buffer") result = buffer.popleft() if result is ReplayableSession.Callable: return self diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 58df643f4..c96d26d32 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -26,7 +26,6 @@ class Requirements(object): class SuiteRequirements(Requirements): - @property def create_table(self): """target platform can emit basic CreateTable DDL.""" @@ -68,8 +67,8 @@ class SuiteRequirements(Requirements): # somehow only_if([x, y]) isn't working here, negation/conjunctions # getting confused. return exclusions.only_if( - lambda: self.on_update_cascade.enabled or - self.deferrable_fks.enabled + lambda: self.on_update_cascade.enabled + or self.deferrable_fks.enabled ) @property @@ -231,22 +230,21 @@ class SuiteRequirements(Requirements): def sane_rowcount(self): return exclusions.skip_if( lambda config: not config.db.dialect.supports_sane_rowcount, - "driver doesn't support 'sane' rowcount" + "driver doesn't support 'sane' rowcount", ) @property def sane_multi_rowcount(self): return exclusions.fails_if( lambda config: not config.db.dialect.supports_sane_multi_rowcount, - "driver %(driver)s %(doesnt_support)s 'sane' multi row count" + "driver %(driver)s %(doesnt_support)s 'sane' multi row count", ) @property def sane_rowcount_w_returning(self): return exclusions.fails_if( - lambda config: - not config.db.dialect.supports_sane_rowcount_returning, - "driver doesn't support 'sane' rowcount when returning is on" + lambda config: not config.db.dialect.supports_sane_rowcount_returning, + "driver doesn't support 'sane' rowcount when returning is on", ) @property @@ -255,9 +253,9 @@ class SuiteRequirements(Requirements): INSERT DEFAULT VALUES or equivalent.""" return exclusions.only_if( - lambda config: config.db.dialect.supports_empty_insert or - config.db.dialect.supports_default_values, - "empty inserts not supported" + lambda config: config.db.dialect.supports_empty_insert + or config.db.dialect.supports_default_values, + "empty inserts not supported", ) @property @@ -272,7 +270,7 @@ class SuiteRequirements(Requirements): return exclusions.only_if( lambda config: config.db.dialect.implicit_returning, - "%(database)s %(does_support)s 'returning'" + "%(database)s %(does_support)s 'returning'", ) @property @@ -297,7 +295,7 @@ class SuiteRequirements(Requirements): return exclusions.skip_if( lambda config: not config.db.dialect.requires_name_normalize, - "Backend does not require denormalized names." + "Backend does not require denormalized names.", ) @property @@ -307,7 +305,7 @@ class SuiteRequirements(Requirements): return exclusions.skip_if( lambda config: not config.db.dialect.supports_multivalues_insert, - "Backend does not support multirow inserts." + "Backend does not support multirow inserts.", ) @property @@ -355,27 +353,32 @@ class SuiteRequirements(Requirements): def server_side_cursors(self): """Target dialect must support server side cursors.""" - return exclusions.only_if([ - lambda config: config.db.dialect.supports_server_side_cursors - ], "no server side cursors support") + return exclusions.only_if( + [lambda config: config.db.dialect.supports_server_side_cursors], + "no server side cursors support", + ) @property def sequences(self): """Target database must support SEQUENCEs.""" - return exclusions.only_if([ - lambda config: config.db.dialect.supports_sequences - ], "no sequence support") + return exclusions.only_if( + [lambda config: config.db.dialect.supports_sequences], + "no sequence support", + ) @property def sequences_optional(self): """Target database supports sequences, but also optionally as a means of generating new PK values.""" - return exclusions.only_if([ - lambda config: config.db.dialect.supports_sequences and - config.db.dialect.sequences_optional - ], "no sequence support, or sequences not optional") + return exclusions.only_if( + [ + lambda config: config.db.dialect.supports_sequences + and config.db.dialect.sequences_optional + ], + "no sequence support, or sequences not optional", + ) @property def reflects_pk_names(self): @@ -841,7 +844,8 @@ class SuiteRequirements(Requirements): """ return exclusions.skip_if( - lambda config: config.options.low_connections) + lambda config: config.options.low_connections + ) @property def timing_intensive(self): @@ -859,37 +863,37 @@ class SuiteRequirements(Requirements): """ return exclusions.skip_if( lambda config: util.py3k and config.options.has_coverage, - "Stability issues with coverage + py3k" + "Stability issues with coverage + py3k", ) @property def python2(self): return exclusions.skip_if( lambda: sys.version_info >= (3,), - "Python version 2.xx is required." + "Python version 2.xx is required.", ) @property def python3(self): return exclusions.skip_if( - lambda: sys.version_info < (3,), - "Python version 3.xx is required." + lambda: sys.version_info < (3,), "Python version 3.xx is required." ) @property def cpython(self): return exclusions.only_if( - lambda: util.cpython, - "cPython interpreter needed" + lambda: util.cpython, "cPython interpreter needed" ) @property def non_broken_pickle(self): from sqlalchemy.util import pickle + return exclusions.only_if( - lambda: not util.pypy and pickle.__name__ == 'cPickle' - or sys.version_info >= (3, 2), - "Needs cPickle+cPython or newer Python 3 pickle" + lambda: not util.pypy + and pickle.__name__ == "cPickle" + or sys.version_info >= (3, 2), + "Needs cPickle+cPython or newer Python 3 pickle", ) @property @@ -910,7 +914,7 @@ class SuiteRequirements(Requirements): """ return exclusions.skip_if( lambda config: config.options.has_coverage, - "Issues observed when coverage is enabled" + "Issues observed when coverage is enabled", ) def _has_mysql_on_windows(self, config): @@ -931,8 +935,9 @@ class SuiteRequirements(Requirements): def _has_sqlite(self): from sqlalchemy import create_engine + try: - create_engine('sqlite://') + create_engine("sqlite://") return True except ImportError: return False @@ -940,6 +945,7 @@ class SuiteRequirements(Requirements): def _has_cextensions(self): try: from sqlalchemy import cresultproxy, cprocessors + return True except ImportError: return False diff --git a/lib/sqlalchemy/testing/runner.py b/lib/sqlalchemy/testing/runner.py index 87be0749c..6aa820fd5 100644 --- a/lib/sqlalchemy/testing/runner.py +++ b/lib/sqlalchemy/testing/runner.py @@ -47,4 +47,4 @@ def setup_py_test(): to nose. """ - nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner']) + nose.main(addplugins=[NoseSQLAlchemy()], argv=["runner"]) diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 401c8cbb7..b345a9487 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -9,7 +9,7 @@ from . import exclusions from .. import schema, event from . import config -__all__ = 'Table', 'Column', +__all__ = "Table", "Column" table_options = {} @@ -17,30 +17,35 @@ table_options = {} def Table(*args, **kw): """A schema.Table wrapper/hook for dialect-specific tweaks.""" - test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')} + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} kw.update(table_options) - if exclusions.against(config._current, 'mysql'): - if 'mysql_engine' not in kw and 'mysql_type' not in kw and \ - "autoload_with" not in kw: - if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts: - kw['mysql_engine'] = 'InnoDB' + if exclusions.against(config._current, "mysql"): + if ( + "mysql_engine" not in kw + and "mysql_type" not in kw + and "autoload_with" not in kw + ): + if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts: + kw["mysql_engine"] = "InnoDB" else: - kw['mysql_engine'] = 'MyISAM' + kw["mysql_engine"] = "MyISAM" # Apply some default cascading rules for self-referential foreign keys. # MySQL InnoDB has some issues around seleting self-refs too. - if exclusions.against(config._current, 'firebird'): + if exclusions.against(config._current, "firebird"): table_name = args[0] - unpack = (config.db.dialect. - identifier_preparer.unformat_identifiers) + unpack = config.db.dialect.identifier_preparer.unformat_identifiers # Only going after ForeignKeys in Columns. May need to # expand to ForeignKeyConstraint too. - fks = [fk - for col in args if isinstance(col, schema.Column) - for fk in col.foreign_keys] + fks = [ + fk + for col in args + if isinstance(col, schema.Column) + for fk in col.foreign_keys + ] for fk in fks: # root around in raw spec @@ -54,9 +59,9 @@ def Table(*args, **kw): name = unpack(ref)[0] if name == table_name: if fk.ondelete is None: - fk.ondelete = 'CASCADE' + fk.ondelete = "CASCADE" if fk.onupdate is None: - fk.onupdate = 'CASCADE' + fk.onupdate = "CASCADE" return schema.Table(*args, **kw) @@ -64,37 +69,46 @@ def Table(*args, **kw): def Column(*args, **kw): """A schema.Column wrapper/hook for dialect-specific tweaks.""" - test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith('test_')} + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} if not config.requirements.foreign_key_ddl.enabled_for_config(config): args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)] col = schema.Column(*args, **kw) - if test_opts.get('test_needs_autoincrement', False) and \ - kw.get('primary_key', False): + if test_opts.get("test_needs_autoincrement", False) and kw.get( + "primary_key", False + ): if col.default is None and col.server_default is None: col.autoincrement = True # allow any test suite to pick up on this - col.info['test_needs_autoincrement'] = True + col.info["test_needs_autoincrement"] = True # hardcoded rule for firebird, oracle; this should # be moved out - if exclusions.against(config._current, 'firebird', 'oracle'): + if exclusions.against(config._current, "firebird", "oracle"): + def add_seq(c, tbl): c._init_items( - schema.Sequence(_truncate_name( - config.db.dialect, tbl.name + '_' + c.name + '_seq'), - optional=True) + schema.Sequence( + _truncate_name( + config.db.dialect, tbl.name + "_" + c.name + "_seq" + ), + optional=True, + ) ) - event.listen(col, 'after_parent_attach', add_seq, propagate=True) + + event.listen(col, "after_parent_attach", add_seq, propagate=True) return col def _truncate_name(dialect, name): if len(name) > dialect.max_identifier_length: - return name[0:max(dialect.max_identifier_length - 6, 0)] + \ - "_" + hex(hash(name) % 64)[2:] + return ( + name[0 : max(dialect.max_identifier_length - 6, 0)] + + "_" + + hex(hash(name) % 64)[2:] + ) else: return name diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py index 748d9722d..a4e142c5a 100644 --- a/lib/sqlalchemy/testing/suite/__init__.py +++ b/lib/sqlalchemy/testing/suite/__init__.py @@ -1,4 +1,3 @@ - from sqlalchemy.testing.suite.test_cte import * from sqlalchemy.testing.suite.test_dialect import * from sqlalchemy.testing.suite.test_ddl import * diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py index cc72278e6..d2f35933b 100644 --- a/lib/sqlalchemy/testing/suite/test_cte.py +++ b/lib/sqlalchemy/testing/suite/test_cte.py @@ -10,22 +10,28 @@ from ..schema import Table, Column class CTETest(fixtures.TablesTest): __backend__ = True - __requires__ = 'ctes', + __requires__ = ("ctes",) - run_inserts = 'each' - run_deletes = 'each' + run_inserts = "each" + run_deletes = "each" @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column("parent_id", ForeignKey("some_table.id"))) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", ForeignKey("some_table.id")), + ) - Table("some_other_table", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)), - Column("parent_id", Integer)) + Table( + "some_other_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", Integer), + ) @classmethod def insert_data(cls): @@ -36,28 +42,33 @@ class CTETest(fixtures.TablesTest): {"id": 2, "data": "d2", "parent_id": 1}, {"id": 3, "data": "d3", "parent_id": 1}, {"id": 4, "data": "d4", "parent_id": 3}, - {"id": 5, "data": "d5", "parent_id": 3} - ] + {"id": 5, "data": "d5", "parent_id": 3}, + ], ) def test_select_nonrecursive_round_trip(self): some_table = self.tables.some_table with config.db.connect() as conn: - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) result = conn.execute( select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"])) ) - eq_(result.fetchall(), [("d4", )]) + eq_(result.fetchall(), [("d4",)]) def test_select_recursive_round_trip(self): some_table = self.tables.some_table with config.db.connect() as conn: - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"])).cte( - "some_cte", recursive=True) + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte", recursive=True) + ) cte_alias = cte.alias("c1") st1 = some_table.alias() @@ -67,12 +78,13 @@ class CTETest(fixtures.TablesTest): select([st1]).where(st1.c.id == cte_alias.c.parent_id) ) result = conn.execute( - select([cte.c.data]).where( - cte.c.data != "d2").order_by(cte.c.data.desc()) + select([cte.c.data]) + .where(cte.c.data != "d2") + .order_by(cte.c.data.desc()) ) eq_( result.fetchall(), - [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)] + [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)], ) def test_insert_from_select_round_trip(self): @@ -80,20 +92,21 @@ class CTETest(fixtures.TablesTest): some_other_table = self.tables.some_other_table with config.db.connect() as conn: - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( some_other_table.insert().from_select( - ["id", "data", "parent_id"], - select([cte]) + ["id", "data", "parent_id"], select([cte]) ) ) eq_( conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), - [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)] + [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)], ) @testing.requires.ctes_with_update_delete @@ -105,27 +118,31 @@ class CTETest(fixtures.TablesTest): with config.db.connect() as conn: conn.execute( some_other_table.insert().from_select( - ['id', 'data', 'parent_id'], - select([some_table]) + ["id", "data", "parent_id"], select([some_table]) ) ) - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( - some_other_table.update().values(parent_id=5).where( - some_other_table.c.data == cte.c.data - ) + some_other_table.update() + .values(parent_id=5) + .where(some_other_table.c.data == cte.c.data) ) eq_( conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), [ - (1, "d1", None), (2, "d2", 5), - (3, "d3", 5), (4, "d4", 5), (5, "d5", 3) - ] + (1, "d1", None), + (2, "d2", 5), + (3, "d3", 5), + (4, "d4", 5), + (5, "d5", 3), + ], ) @testing.requires.ctes_with_update_delete @@ -137,14 +154,15 @@ class CTETest(fixtures.TablesTest): with config.db.connect() as conn: conn.execute( some_other_table.insert().from_select( - ['id', 'data', 'parent_id'], - select([some_table]) + ["id", "data", "parent_id"], select([some_table]) ) ) - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( some_other_table.delete().where( some_other_table.c.data == cte.c.data @@ -154,9 +172,7 @@ class CTETest(fixtures.TablesTest): conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), - [ - (1, "d1", None), (5, "d5", 3) - ] + [(1, "d1", None), (5, "d5", 3)], ) @testing.requires.ctes_with_update_delete @@ -168,26 +184,26 @@ class CTETest(fixtures.TablesTest): with config.db.connect() as conn: conn.execute( some_other_table.insert().from_select( - ['id', 'data', 'parent_id'], - select([some_table]) + ["id", "data", "parent_id"], select([some_table]) ) ) - cte = select([some_table]).where( - some_table.c.data.in_(["d2", "d3", "d4"]) - ).cte("some_cte") + cte = ( + select([some_table]) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) conn.execute( some_other_table.delete().where( - some_other_table.c.data == - select([cte.c.data]).where( - cte.c.id == some_other_table.c.id) + some_other_table.c.data + == select([cte.c.data]).where( + cte.c.id == some_other_table.c.id + ) ) ) eq_( conn.execute( select([some_other_table]).order_by(some_other_table.c.id) ).fetchall(), - [ - (1, "d1", None), (5, "d5", 3) - ] + [(1, "d1", None), (5, "d5", 3)], ) diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py index 1d8010c8a..7c44388d4 100644 --- a/lib/sqlalchemy/testing/suite/test_ddl.py +++ b/lib/sqlalchemy/testing/suite/test_ddl.py @@ -1,5 +1,3 @@ - - from .. import fixtures, config, util from ..config import requirements from ..assertions import eq_ @@ -11,55 +9,47 @@ class TableDDLTest(fixtures.TestBase): __backend__ = True def _simple_fixture(self): - return Table('test_table', self.metadata, - Column('id', Integer, primary_key=True, - autoincrement=False), - Column('data', String(50)) - ) + return Table( + "test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) def _underscore_fixture(self): - return Table('_test_table', self.metadata, - Column('id', Integer, primary_key=True, - autoincrement=False), - Column('_data', String(50)) - ) + return Table( + "_test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("_data", String(50)), + ) def _simple_roundtrip(self, table): with config.db.begin() as conn: - conn.execute(table.insert().values((1, 'some data'))) + conn.execute(table.insert().values((1, "some data"))) result = conn.execute(table.select()) - eq_( - result.first(), - (1, 'some data') - ) + eq_(result.first(), (1, "some data")) @requirements.create_table @util.provide_metadata def test_create_table(self): table = self._simple_fixture() - table.create( - config.db, checkfirst=False - ) + table.create(config.db, checkfirst=False) self._simple_roundtrip(table) @requirements.drop_table @util.provide_metadata def test_drop_table(self): table = self._simple_fixture() - table.create( - config.db, checkfirst=False - ) - table.drop( - config.db, checkfirst=False - ) + table.create(config.db, checkfirst=False) + table.drop(config.db, checkfirst=False) @requirements.create_table @util.provide_metadata def test_underscore_names(self): table = self._underscore_fixture() - table.create( - config.db, checkfirst=False - ) + table.create(config.db, checkfirst=False) self._simple_roundtrip(table) -__all__ = ('TableDDLTest', ) + +__all__ = ("TableDDLTest",) diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index 2c5dd0e36..5e589f3b8 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -15,16 +15,19 @@ class ExceptionTest(fixtures.TablesTest): specific exceptions from real round trips, we need to be conservative. """ - run_deletes = 'each' + + run_deletes = "each" __backend__ = True @classmethod def define_tables(cls, metadata): - Table('manual_pk', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) @requirements.duplicate_key_raises_integrity_error def test_integrity_error(self): @@ -33,15 +36,14 @@ class ExceptionTest(fixtures.TablesTest): trans = conn.begin() conn.execute( - self.tables.manual_pk.insert(), - {'id': 1, 'data': 'd1'} + self.tables.manual_pk.insert(), {"id": 1, "data": "d1"} ) assert_raises( exc.IntegrityError, conn.execute, self.tables.manual_pk.insert(), - {'id': 1, 'data': 'd1'} + {"id": 1, "data": "d1"}, ) trans.rollback() @@ -49,38 +51,39 @@ class ExceptionTest(fixtures.TablesTest): class AutocommitTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" - __requires__ = 'autocommit', + __requires__ = ("autocommit",) __backend__ = True @classmethod def define_tables(cls, metadata): - Table('some_table', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)), - test_needs_acid=True - ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + test_needs_acid=True, + ) def _test_conn_autocommits(self, conn, autocommit): trans = conn.begin() conn.execute( - self.tables.some_table.insert(), - {"id": 1, "data": "some data"} + self.tables.some_table.insert(), {"id": 1, "data": "some data"} ) trans.rollback() eq_( conn.scalar(select([self.tables.some_table.c.id])), - 1 if autocommit else None + 1 if autocommit else None, ) conn.execute(self.tables.some_table.delete()) def test_autocommit_on(self): conn = config.db.connect() - c2 = conn.execution_options(isolation_level='AUTOCOMMIT') + c2 = conn.execution_options(isolation_level="AUTOCOMMIT") self._test_conn_autocommits(c2, True) conn.invalidate() self._test_conn_autocommits(conn, False) @@ -98,7 +101,7 @@ class EscapingTest(fixtures.TestBase): """ m = self.metadata - t = Table('t', m, Column('data', String(50))) + t = Table("t", m, Column("data", String(50))) t.create(config.db) with config.db.begin() as conn: conn.execute(t.insert(), dict(data="some % value")) @@ -107,14 +110,17 @@ class EscapingTest(fixtures.TestBase): eq_( conn.scalar( select([t.c.data]).where( - t.c.data == literal_column("'some % value'")) + t.c.data == literal_column("'some % value'") + ) ), - "some % value" + "some % value", ) eq_( conn.scalar( select([t.c.data]).where( - t.c.data == literal_column("'some %% other value'")) - ), "some %% other value" + t.c.data == literal_column("'some %% other value'") + ) + ), + "some %% other value", ) diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index c0b6b18eb..6257451eb 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -10,53 +10,48 @@ from ..schema import Table, Column class LastrowidTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" __backend__ = True - __requires__ = 'implements_get_lastrowid', 'autoincrement_insert' + __requires__ = "implements_get_lastrowid", "autoincrement_insert" __engine_options__ = {"implicit_returning": False} @classmethod def define_tables(cls, metadata): - Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)) - ) + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) - Table('manual_pk', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) def _assert_round_trip(self, table, conn): row = conn.execute(table.select()).first() - eq_( - row, - (config.db.dialect.default_sequence_base, "some data") - ) + eq_(row, (config.db.dialect.default_sequence_base, "some data")) def test_autoincrement_on_insert(self): - config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" - ) + config.db.execute(self.tables.autoinc_pk.insert(), data="some data") self._assert_round_trip(self.tables.autoinc_pk, config.db) def test_last_inserted_id(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) - eq_( - r.inserted_primary_key, - [pk] - ) + eq_(r.inserted_primary_key, [pk]) # failed on pypy1.9 but seems to be OK on pypy 2.1 # @exclusions.fails_if(lambda: util.pypy, @@ -65,50 +60,57 @@ class LastrowidTest(fixtures.TablesTest): @requirements.dbapi_lastrowid def test_native_lastrowid_autoinc(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) lastrowid = r.lastrowid pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) - eq_( - lastrowid, pk - ) + eq_(lastrowid, pk) class InsertBehaviorTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" __backend__ = True @classmethod def define_tables(cls, metadata): - Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)) - ) - Table('manual_pk', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('data', String(50)) - ) - Table('includes_defaults', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)), - Column('x', Integer, default=5), - Column('y', Integer, - default=literal_column("2", type_=Integer) + literal(2))) + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + Table( + "includes_defaults", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("x", Integer, default=5), + Column( + "y", + Integer, + default=literal_column("2", type_=Integer) + literal(2), + ), + ) def test_autoclose_on_insert(self): if requirements.returning.enabled: engine = engines.testing_engine( - options={'implicit_returning': False}) + options={"implicit_returning": False} + ) else: engine = config.db - r = engine.execute( - self.tables.autoinc_pk.insert(), - data="some data" - ) + r = engine.execute(self.tables.autoinc_pk.insert(), data="some data") assert r._soft_closed assert not r.closed assert r.is_insert @@ -117,8 +119,7 @@ class InsertBehaviorTest(fixtures.TablesTest): @requirements.returning def test_autoclose_on_insert_implicit_returning(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) assert r._soft_closed assert not r.closed @@ -127,15 +128,14 @@ class InsertBehaviorTest(fixtures.TablesTest): @requirements.empty_inserts def test_empty_insert(self): - r = config.db.execute( - self.tables.autoinc_pk.insert(), - ) + r = config.db.execute(self.tables.autoinc_pk.insert()) assert r._soft_closed assert not r.closed r = config.db.execute( - self.tables.autoinc_pk.select(). - where(self.tables.autoinc_pk.c.id != None) + self.tables.autoinc_pk.select().where( + self.tables.autoinc_pk.c.id != None + ) ) assert len(r.fetchall()) @@ -150,15 +150,15 @@ class InsertBehaviorTest(fixtures.TablesTest): dict(id=1, data="data1"), dict(id=2, data="data2"), dict(id=3, data="data3"), - ] + ], ) result = config.db.execute( - dest_table.insert(). - from_select( + dest_table.insert().from_select( ("data",), - select([src_table.c.data]). - where(src_table.c.data.in_(["data2", "data3"])) + select([src_table.c.data]).where( + src_table.c.data.in_(["data2", "data3"]) + ), ) ) @@ -167,7 +167,7 @@ class InsertBehaviorTest(fixtures.TablesTest): result = config.db.execute( select([dest_table.c.data]).order_by(dest_table.c.data) ) - eq_(result.fetchall(), [("data2", ), ("data3", )]) + eq_(result.fetchall(), [("data2",), ("data3",)]) @requirements.insert_from_select def test_insert_from_select_autoinc_no_rows(self): @@ -175,11 +175,11 @@ class InsertBehaviorTest(fixtures.TablesTest): dest_table = self.tables.autoinc_pk result = config.db.execute( - dest_table.insert(). - from_select( + dest_table.insert().from_select( ("data",), - select([src_table.c.data]). - where(src_table.c.data.in_(["data2", "data3"])) + select([src_table.c.data]).where( + src_table.c.data.in_(["data2", "data3"]) + ), ) ) eq_(result.inserted_primary_key, [None]) @@ -199,23 +199,23 @@ class InsertBehaviorTest(fixtures.TablesTest): dict(id=1, data="data1"), dict(id=2, data="data2"), dict(id=3, data="data3"), - ] + ], ) config.db.execute( - table.insert(inline=True). - from_select(("id", "data",), - select([table.c.id + 5, table.c.data]). - where(table.c.data.in_(["data2", "data3"])) - ), + table.insert(inline=True).from_select( + ("id", "data"), + select([table.c.id + 5, table.c.data]).where( + table.c.data.in_(["data2", "data3"]) + ), + ) ) eq_( config.db.execute( select([table.c.data]).order_by(table.c.data) ).fetchall(), - [("data1", ), ("data2", ), ("data2", ), - ("data3", ), ("data3", )] + [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)], ) @requirements.insert_from_select @@ -227,56 +227,60 @@ class InsertBehaviorTest(fixtures.TablesTest): dict(id=1, data="data1"), dict(id=2, data="data2"), dict(id=3, data="data3"), - ] + ], ) config.db.execute( - table.insert(inline=True). - from_select(("id", "data",), - select([table.c.id + 5, table.c.data]). - where(table.c.data.in_(["data2", "data3"])) - ), + table.insert(inline=True).from_select( + ("id", "data"), + select([table.c.id + 5, table.c.data]).where( + table.c.data.in_(["data2", "data3"]) + ), + ) ) eq_( config.db.execute( select([table]).order_by(table.c.data, table.c.id) ).fetchall(), - [(1, 'data1', 5, 4), (2, 'data2', 5, 4), - (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)] + [ + (1, "data1", 5, 4), + (2, "data2", 5, 4), + (7, "data2", 5, 4), + (3, "data3", 5, 4), + (8, "data3", 5, 4), + ], ) class ReturningTest(fixtures.TablesTest): - run_create_tables = 'each' - __requires__ = 'returning', 'autoincrement_insert' + run_create_tables = "each" + __requires__ = "returning", "autoincrement_insert" __backend__ = True __engine_options__ = {"implicit_returning": True} def _assert_round_trip(self, table, conn): row = conn.execute(table.select()).first() - eq_( - row, - (config.db.dialect.default_sequence_base, "some data") - ) + eq_(row, (config.db.dialect.default_sequence_base, "some data")) @classmethod def define_tables(cls, metadata): - Table('autoinc_pk', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('data', String(50)) - ) + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) @requirements.fetch_rows_post_commit def test_explicit_returning_pk_autocommit(self): engine = config.db table = self.tables.autoinc_pk r = engine.execute( - table.insert().returning( - table.c.id), - data="some data" + table.insert().returning(table.c.id), data="some data" ) pk = r.first()[0] fetched_pk = config.db.scalar(select([table.c.id])) @@ -287,9 +291,7 @@ class ReturningTest(fixtures.TablesTest): table = self.tables.autoinc_pk with engine.begin() as conn: r = conn.execute( - table.insert().returning( - table.c.id), - data="some data" + table.insert().returning(table.c.id), data="some data" ) pk = r.first()[0] fetched_pk = config.db.scalar(select([table.c.id])) @@ -297,23 +299,16 @@ class ReturningTest(fixtures.TablesTest): def test_autoincrement_on_insert_implcit_returning(self): - config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" - ) + config.db.execute(self.tables.autoinc_pk.insert(), data="some data") self._assert_round_trip(self.tables.autoinc_pk, config.db) def test_last_inserted_id_implicit_returning(self): r = config.db.execute( - self.tables.autoinc_pk.insert(), - data="some data" + self.tables.autoinc_pk.insert(), data="some data" ) pk = config.db.scalar(select([self.tables.autoinc_pk.c.id])) - eq_( - r.inserted_primary_key, - [pk] - ) + eq_(r.inserted_primary_key, [pk]) -__all__ = ('LastrowidTest', 'InsertBehaviorTest', 'ReturningTest') +__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest") diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 00a5aac01..bfed5f1ab 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -1,5 +1,3 @@ - - import sqlalchemy as sa from sqlalchemy import exc as sa_exc from sqlalchemy import types as sql_types @@ -26,10 +24,12 @@ class HasTableTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table('test_table', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) + Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) def test_has_table(self): with config.db.begin() as conn: @@ -46,8 +46,10 @@ class ComponentReflectionTest(fixtures.TablesTest): def setup_bind(cls): if config.requirements.independent_connections.enabled: from sqlalchemy import pool + return engines.testing_engine( - options=dict(poolclass=pool.StaticPool)) + options=dict(poolclass=pool.StaticPool) + ) else: return config.db @@ -65,86 +67,109 @@ class ComponentReflectionTest(fixtures.TablesTest): schema_prefix = "" if testing.requires.self_referential_foreign_keys.enabled: - users = Table('users', metadata, - Column('user_id', sa.INT, primary_key=True), - Column('test1', sa.CHAR(5), nullable=False), - Column('test2', sa.Float(5), nullable=False), - Column('parent_user_id', sa.Integer, - sa.ForeignKey('%susers.user_id' % - schema_prefix, - name='user_id_fk')), - schema=schema, - test_needs_fk=True, - ) + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + Column( + "parent_user_id", + sa.Integer, + sa.ForeignKey( + "%susers.user_id" % schema_prefix, name="user_id_fk" + ), + ), + schema=schema, + test_needs_fk=True, + ) else: - users = Table('users', metadata, - Column('user_id', sa.INT, primary_key=True), - Column('test1', sa.CHAR(5), nullable=False), - Column('test2', sa.Float(5), nullable=False), - schema=schema, - test_needs_fk=True, - ) - - Table("dingalings", metadata, - Column('dingaling_id', sa.Integer, primary_key=True), - Column('address_id', sa.Integer, - sa.ForeignKey('%semail_addresses.address_id' % - schema_prefix)), - Column('data', sa.String(30)), - schema=schema, - test_needs_fk=True, - ) - Table('email_addresses', metadata, - Column('address_id', sa.Integer), - Column('remote_user_id', sa.Integer, - sa.ForeignKey(users.c.user_id)), - Column('email_address', sa.String(20)), - sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'), - schema=schema, - test_needs_fk=True, - ) - Table('comment_test', metadata, - Column('id', sa.Integer, primary_key=True, comment='id comment'), - Column('data', sa.String(20), comment='data % comment'), - Column( - 'd2', sa.String(20), - comment=r"""Comment types type speedily ' " \ '' Fun!"""), - schema=schema, - comment=r"""the test % ' " \ table comment""") + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + schema=schema, + test_needs_fk=True, + ) + + Table( + "dingalings", + metadata, + Column("dingaling_id", sa.Integer, primary_key=True), + Column( + "address_id", + sa.Integer, + sa.ForeignKey("%semail_addresses.address_id" % schema_prefix), + ), + Column("data", sa.String(30)), + schema=schema, + test_needs_fk=True, + ) + Table( + "email_addresses", + metadata, + Column("address_id", sa.Integer), + Column( + "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id) + ), + Column("email_address", sa.String(20)), + sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"), + schema=schema, + test_needs_fk=True, + ) + Table( + "comment_test", + metadata, + Column("id", sa.Integer, primary_key=True, comment="id comment"), + Column("data", sa.String(20), comment="data % comment"), + Column( + "d2", + sa.String(20), + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + schema=schema, + comment=r"""the test % ' " \ table comment""", + ) if testing.requires.cross_schema_fk_reflection.enabled: if schema is None: Table( - 'local_table', metadata, - Column('id', sa.Integer, primary_key=True), - Column('data', sa.String(20)), + "local_table", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), Column( - 'remote_id', + "remote_id", ForeignKey( - '%s.remote_table_2.id' % - testing.config.test_schema) + "%s.remote_table_2.id" % testing.config.test_schema + ), ), test_needs_fk=True, - schema=config.db.dialect.default_schema_name + schema=config.db.dialect.default_schema_name, ) else: Table( - 'remote_table', metadata, - Column('id', sa.Integer, primary_key=True), + "remote_table", + metadata, + Column("id", sa.Integer, primary_key=True), Column( - 'local_id', + "local_id", ForeignKey( - '%s.local_table.id' % - config.db.dialect.default_schema_name) + "%s.local_table.id" + % config.db.dialect.default_schema_name + ), ), - Column('data', sa.String(20)), + Column("data", sa.String(20)), schema=schema, test_needs_fk=True, ) Table( - 'remote_table_2', metadata, - Column('id', sa.Integer, primary_key=True), - Column('data', sa.String(20)), + "remote_table_2", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), schema=schema, test_needs_fk=True, ) @@ -155,19 +180,21 @@ class ComponentReflectionTest(fixtures.TablesTest): if not schema: # test_needs_fk is at the moment to force MySQL InnoDB noncol_idx_test_nopk = Table( - 'noncol_idx_test_nopk', metadata, - Column('q', sa.String(5)), + "noncol_idx_test_nopk", + metadata, + Column("q", sa.String(5)), test_needs_fk=True, ) noncol_idx_test_pk = Table( - 'noncol_idx_test_pk', metadata, - Column('id', sa.Integer, primary_key=True), - Column('q', sa.String(5)), + "noncol_idx_test_pk", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("q", sa.String(5)), test_needs_fk=True, ) - Index('noncol_idx_nopk', noncol_idx_test_nopk.c.q.desc()) - Index('noncol_idx_pk', noncol_idx_test_pk.c.q.desc()) + Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) + Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) if testing.requires.view_column_reflection.enabled: cls.define_views(metadata, schema) @@ -180,34 +207,35 @@ class ComponentReflectionTest(fixtures.TablesTest): # temp table fixture if testing.against("oracle"): kw = { - 'prefixes': ["GLOBAL TEMPORARY"], - 'oracle_on_commit': 'PRESERVE ROWS' + "prefixes": ["GLOBAL TEMPORARY"], + "oracle_on_commit": "PRESERVE ROWS", } else: - kw = { - 'prefixes': ["TEMPORARY"], - } + kw = {"prefixes": ["TEMPORARY"]} user_tmp = Table( - "user_tmp", metadata, + "user_tmp", + metadata, Column("id", sa.INT, primary_key=True), - Column('name', sa.VARCHAR(50)), - Column('foo', sa.INT), - sa.UniqueConstraint('name', name='user_tmp_uq'), + Column("name", sa.VARCHAR(50)), + Column("foo", sa.INT), + sa.UniqueConstraint("name", name="user_tmp_uq"), sa.Index("user_tmp_ix", "foo"), **kw ) - if testing.requires.view_reflection.enabled and \ - testing.requires.temporary_views.enabled: - event.listen( - user_tmp, "after_create", - DDL("create temporary view user_tmp_v as " - "select * from user_tmp") - ) + if ( + testing.requires.view_reflection.enabled + and testing.requires.temporary_views.enabled + ): event.listen( - user_tmp, "before_drop", - DDL("drop view user_tmp_v") + user_tmp, + "after_create", + DDL( + "create temporary view user_tmp_v as " + "select * from user_tmp" + ), ) + event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) @classmethod def define_index(cls, metadata, users): @@ -216,23 +244,19 @@ class ComponentReflectionTest(fixtures.TablesTest): @classmethod def define_views(cls, metadata, schema): - for table_name in ('users', 'email_addresses'): + for table_name in ("users", "email_addresses"): fullname = table_name if schema: fullname = "%s.%s" % (schema, table_name) - view_name = fullname + '_v' + view_name = fullname + "_v" query = "CREATE VIEW %s AS SELECT * FROM %s" % ( - view_name, fullname) - - event.listen( - metadata, - "after_create", - DDL(query) + view_name, + fullname, ) + + event.listen(metadata, "after_create", DDL(query)) event.listen( - metadata, - "before_drop", - DDL("DROP VIEW %s" % view_name) + metadata, "before_drop", DDL("DROP VIEW %s" % view_name) ) @testing.requires.schema_reflection @@ -244,9 +268,9 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schema_reflection def test_dialect_initialize(self): engine = engines.testing_engine() - assert not hasattr(engine.dialect, 'default_schema_name') + assert not hasattr(engine.dialect, "default_schema_name") inspect(engine) - assert hasattr(engine.dialect, 'default_schema_name') + assert hasattr(engine.dialect, "default_schema_name") @testing.requires.schema_reflection def test_get_default_schema_name(self): @@ -254,40 +278,49 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(insp.default_schema_name, testing.db.dialect.default_schema_name) @testing.provide_metadata - def _test_get_table_names(self, schema=None, table_type='table', - order_by=None): + def _test_get_table_names( + self, schema=None, table_type="table", order_by=None + ): _ignore_tables = [ - 'comment_test', 'noncol_idx_test_pk', 'noncol_idx_test_nopk', - 'local_table', 'remote_table', 'remote_table_2' + "comment_test", + "noncol_idx_test_pk", + "noncol_idx_test_nopk", + "local_table", + "remote_table", + "remote_table_2", ] meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) insp = inspect(meta.bind) - if table_type == 'view': + if table_type == "view": table_names = insp.get_view_names(schema) table_names.sort() - answer = ['email_addresses_v', 'users_v'] + answer = ["email_addresses_v", "users_v"] eq_(sorted(table_names), answer) else: table_names = [ - t for t in insp.get_table_names( - schema, - order_by=order_by) if t not in _ignore_tables] + t + for t in insp.get_table_names(schema, order_by=order_by) + if t not in _ignore_tables + ] - if order_by == 'foreign_key': - answer = ['users', 'email_addresses', 'dingalings'] + if order_by == "foreign_key": + answer = ["users", "email_addresses", "dingalings"] eq_(table_names, answer) else: - answer = ['dingalings', 'email_addresses', 'users'] + answer = ["dingalings", "email_addresses", "users"] eq_(sorted(table_names), answer) @testing.requires.temp_table_names def test_get_temp_table_names(self): insp = inspect(self.bind) temp_table_names = insp.get_temp_table_names() - eq_(sorted(temp_table_names), ['user_tmp']) + eq_(sorted(temp_table_names), ["user_tmp"]) @testing.requires.view_reflection @testing.requires.temp_table_names @@ -295,7 +328,7 @@ class ComponentReflectionTest(fixtures.TablesTest): def test_get_temp_view_names(self): insp = inspect(self.bind) temp_table_names = insp.get_temp_view_names() - eq_(sorted(temp_table_names), ['user_tmp_v']) + eq_(sorted(temp_table_names), ["user_tmp_v"]) @testing.requires.table_reflection def test_get_table_names(self): @@ -304,7 +337,7 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.table_reflection @testing.requires.foreign_key_constraint_reflection def test_get_table_names_fks(self): - self._test_get_table_names(order_by='foreign_key') + self._test_get_table_names(order_by="foreign_key") @testing.requires.comment_reflection def test_get_comments(self): @@ -320,26 +353,24 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_( insp.get_table_comment("comment_test", schema=schema), - {"text": r"""the test % ' " \ table comment"""} + {"text": r"""the test % ' " \ table comment"""}, ) - eq_( - insp.get_table_comment("users", schema=schema), - {"text": None} - ) + eq_(insp.get_table_comment("users", schema=schema), {"text": None}) eq_( [ - {"name": rec['name'], "comment": rec['comment']} - for rec in - insp.get_columns("comment_test", schema=schema) + {"name": rec["name"], "comment": rec["comment"]} + for rec in insp.get_columns("comment_test", schema=schema) ], [ - {'comment': 'id comment', 'name': 'id'}, - {'comment': 'data % comment', 'name': 'data'}, - {'comment': r"""Comment types type speedily ' " \ '' Fun!""", - 'name': 'd2'} - ] + {"comment": "id comment", "name": "id"}, + {"comment": "data % comment", "name": "data"}, + { + "comment": r"""Comment types type speedily ' " \ '' Fun!""", + "name": "d2", + }, + ], ) @testing.requires.table_reflection @@ -349,30 +380,33 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.view_column_reflection def test_get_view_names(self): - self._test_get_table_names(table_type='view') + self._test_get_table_names(table_type="view") @testing.requires.view_column_reflection @testing.requires.schemas def test_get_view_names_with_schema(self): self._test_get_table_names( - testing.config.test_schema, table_type='view') + testing.config.test_schema, table_type="view" + ) @testing.requires.table_reflection @testing.requires.view_column_reflection def test_get_tables_and_views(self): self._test_get_table_names() - self._test_get_table_names(table_type='view') + self._test_get_table_names(table_type="view") - def _test_get_columns(self, schema=None, table_type='table'): + def _test_get_columns(self, schema=None, table_type="table"): meta = MetaData(testing.db) - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings - table_names = ['users', 'email_addresses'] - if table_type == 'view': - table_names = ['users_v', 'email_addresses_v'] + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) + table_names = ["users", "email_addresses"] + if table_type == "view": + table_names = ["users_v", "email_addresses_v"] insp = inspect(meta.bind) - for table_name, table in zip(table_names, (users, - addresses)): + for table_name, table in zip(table_names, (users, addresses)): schema_name = schema cols = insp.get_columns(table_name, schema=schema_name) self.assert_(len(cols) > 0, len(cols)) @@ -380,36 +414,46 @@ class ComponentReflectionTest(fixtures.TablesTest): # should be in order for i, col in enumerate(table.columns): - eq_(col.name, cols[i]['name']) - ctype = cols[i]['type'].__class__ + eq_(col.name, cols[i]["name"]) + ctype = cols[i]["type"].__class__ ctype_def = col.type if isinstance(ctype_def, sa.types.TypeEngine): ctype_def = ctype_def.__class__ # Oracle returns Date for DateTime. - if testing.against('oracle') and ctype_def \ - in (sql_types.Date, sql_types.DateTime): + if testing.against("oracle") and ctype_def in ( + sql_types.Date, + sql_types.DateTime, + ): ctype_def = sql_types.Date # assert that the desired type and return type share # a base within one of the generic types. - self.assert_(len(set(ctype.__mro__). - intersection(ctype_def.__mro__). - intersection([ - sql_types.Integer, - sql_types.Numeric, - sql_types.DateTime, - sql_types.Date, - sql_types.Time, - sql_types.String, - sql_types._Binary, - ])) > 0, '%s(%s), %s(%s)' % - (col.name, col.type, cols[i]['name'], ctype)) + self.assert_( + len( + set(ctype.__mro__) + .intersection(ctype_def.__mro__) + .intersection( + [ + sql_types.Integer, + sql_types.Numeric, + sql_types.DateTime, + sql_types.Date, + sql_types.Time, + sql_types.String, + sql_types._Binary, + ] + ) + ) + > 0, + "%s(%s), %s(%s)" + % (col.name, col.type, cols[i]["name"], ctype), + ) if not col.primary_key: - assert cols[i]['default'] is None + assert cols[i]["default"] is None @testing.requires.table_reflection def test_get_columns(self): @@ -417,24 +461,20 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.provide_metadata def _type_round_trip(self, *types): - t = Table('t', self.metadata, - *[ - Column('t%d' % i, type_) - for i, type_ in enumerate(types) - ] - ) + t = Table( + "t", + self.metadata, + *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] + ) t.create() return [ - c['type'] for c in - inspect(self.metadata.bind).get_columns('t') + c["type"] for c in inspect(self.metadata.bind).get_columns("t") ] @testing.requires.table_reflection def test_numeric_reflection(self): - for typ in self._type_round_trip( - sql_types.Numeric(18, 5), - ): + for typ in self._type_round_trip(sql_types.Numeric(18, 5)): assert isinstance(typ, sql_types.Numeric) eq_(typ.precision, 18) eq_(typ.scale, 5) @@ -448,16 +488,19 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.table_reflection @testing.provide_metadata def test_nullable_reflection(self): - t = Table('t', self.metadata, - Column('a', Integer, nullable=True), - Column('b', Integer, nullable=False)) + t = Table( + "t", + self.metadata, + Column("a", Integer, nullable=True), + Column("b", Integer, nullable=False), + ) t.create() eq_( dict( - (col['name'], col['nullable']) - for col in inspect(self.metadata.bind).get_columns('t') + (col["name"], col["nullable"]) + for col in inspect(self.metadata.bind).get_columns("t") ), - {"a": True, "b": False} + {"a": True, "b": False}, ) @testing.requires.table_reflection @@ -470,32 +513,30 @@ class ComponentReflectionTest(fixtures.TablesTest): meta = MetaData(self.bind) user_tmp = self.tables.user_tmp insp = inspect(meta.bind) - cols = insp.get_columns('user_tmp') + cols = insp.get_columns("user_tmp") self.assert_(len(cols) > 0, len(cols)) for i, col in enumerate(user_tmp.columns): - eq_(col.name, cols[i]['name']) + eq_(col.name, cols[i]["name"]) @testing.requires.temp_table_reflection @testing.requires.view_column_reflection @testing.requires.temporary_views def test_get_temp_view_columns(self): insp = inspect(self.bind) - cols = insp.get_columns('user_tmp_v') - eq_( - [col['name'] for col in cols], - ['id', 'name', 'foo'] - ) + cols = insp.get_columns("user_tmp_v") + eq_([col["name"] for col in cols], ["id", "name", "foo"]) @testing.requires.view_column_reflection def test_get_view_columns(self): - self._test_get_columns(table_type='view') + self._test_get_columns(table_type="view") @testing.requires.view_column_reflection @testing.requires.schemas def test_get_view_columns_with_schema(self): self._test_get_columns( - schema=testing.config.test_schema, table_type='view') + schema=testing.config.test_schema, table_type="view" + ) @testing.provide_metadata def _test_get_pk_constraint(self, schema=None): @@ -504,15 +545,15 @@ class ComponentReflectionTest(fixtures.TablesTest): insp = inspect(meta.bind) users_cons = insp.get_pk_constraint(users.name, schema=schema) - users_pkeys = users_cons['constrained_columns'] - eq_(users_pkeys, ['user_id']) + users_pkeys = users_cons["constrained_columns"] + eq_(users_pkeys, ["user_id"]) addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) - addr_pkeys = addr_cons['constrained_columns'] - eq_(addr_pkeys, ['address_id']) + addr_pkeys = addr_cons["constrained_columns"] + eq_(addr_pkeys, ["address_id"]) with testing.requires.reflects_pk_names.fail_if(): - eq_(addr_cons['name'], 'email_ad_pk') + eq_(addr_cons["name"], "email_ad_pk") @testing.requires.primary_key_constraint_reflection def test_get_pk_constraint(self): @@ -534,44 +575,46 @@ class ComponentReflectionTest(fixtures.TablesTest): sa_exc.SADeprecationWarning, "Call to deprecated method get_primary_keys." " Use get_pk_constraint instead.", - insp.get_primary_keys, users.name + insp.get_primary_keys, + users.name, ) @testing.provide_metadata def _test_get_foreign_keys(self, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) insp = inspect(meta.bind) expected_schema = schema # users if testing.requires.self_referential_foreign_keys.enabled: - users_fkeys = insp.get_foreign_keys(users.name, - schema=schema) + users_fkeys = insp.get_foreign_keys(users.name, schema=schema) fkey1 = users_fkeys[0] with testing.requires.named_constraints.fail_if(): - eq_(fkey1['name'], "user_id_fk") + eq_(fkey1["name"], "user_id_fk") - eq_(fkey1['referred_schema'], expected_schema) - eq_(fkey1['referred_table'], users.name) - eq_(fkey1['referred_columns'], ['user_id', ]) + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) if testing.requires.self_referential_foreign_keys.enabled: - eq_(fkey1['constrained_columns'], ['parent_user_id']) + eq_(fkey1["constrained_columns"], ["parent_user_id"]) # addresses - addr_fkeys = insp.get_foreign_keys(addresses.name, - schema=schema) + addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) fkey1 = addr_fkeys[0] with testing.requires.implicitly_named_constraints.fail_if(): - self.assert_(fkey1['name'] is not None) + self.assert_(fkey1["name"] is not None) - eq_(fkey1['referred_schema'], expected_schema) - eq_(fkey1['referred_table'], users.name) - eq_(fkey1['referred_columns'], ['user_id', ]) - eq_(fkey1['constrained_columns'], ['remote_user_id']) + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + eq_(fkey1["constrained_columns"], ["remote_user_id"]) @testing.requires.foreign_key_constraint_reflection def test_get_foreign_keys(self): @@ -586,9 +629,9 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.schemas def test_get_inter_schema_foreign_keys(self): local_table, remote_table, remote_table_2 = self.tables( - '%s.local_table' % testing.db.dialect.default_schema_name, - '%s.remote_table' % testing.config.test_schema, - '%s.remote_table_2' % testing.config.test_schema + "%s.local_table" % testing.db.dialect.default_schema_name, + "%s.remote_table" % testing.config.test_schema, + "%s.remote_table_2" % testing.config.test_schema, ) insp = inspect(config.db) @@ -597,25 +640,25 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(len(local_fkeys), 1) fkey1 = local_fkeys[0] - eq_(fkey1['referred_schema'], testing.config.test_schema) - eq_(fkey1['referred_table'], remote_table_2.name) - eq_(fkey1['referred_columns'], ['id', ]) - eq_(fkey1['constrained_columns'], ['remote_id']) + eq_(fkey1["referred_schema"], testing.config.test_schema) + eq_(fkey1["referred_table"], remote_table_2.name) + eq_(fkey1["referred_columns"], ["id"]) + eq_(fkey1["constrained_columns"], ["remote_id"]) remote_fkeys = insp.get_foreign_keys( - remote_table.name, schema=testing.config.test_schema) + remote_table.name, schema=testing.config.test_schema + ) eq_(len(remote_fkeys), 1) fkey2 = remote_fkeys[0] - assert fkey2['referred_schema'] in ( + assert fkey2["referred_schema"] in ( None, - testing.db.dialect.default_schema_name + testing.db.dialect.default_schema_name, ) - eq_(fkey2['referred_table'], local_table.name) - eq_(fkey2['referred_columns'], ['id', ]) - eq_(fkey2['constrained_columns'], ['local_id']) - + eq_(fkey2["referred_table"], local_table.name) + eq_(fkey2["referred_columns"], ["id"]) + eq_(fkey2["constrained_columns"], ["local_id"]) @testing.requires.foreign_key_constraint_option_reflection_ondelete def test_get_foreign_key_options_ondelete(self): @@ -630,26 +673,32 @@ class ComponentReflectionTest(fixtures.TablesTest): meta = self.metadata Table( - 'x', meta, - Column('id', Integer, primary_key=True), - test_needs_fk=True - ) - - Table('table', meta, - Column('id', Integer, primary_key=True), - Column('x_id', Integer, sa.ForeignKey('x.id', name='xid')), - Column('test', String(10)), - test_needs_fk=True) - - Table('user', meta, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('tid', Integer), - sa.ForeignKeyConstraint( - ['tid'], ['table.id'], - name='myfk', - **options), - test_needs_fk=True) + "x", + meta, + Column("id", Integer, primary_key=True), + test_needs_fk=True, + ) + + Table( + "table", + meta, + Column("id", Integer, primary_key=True), + Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")), + Column("test", String(10)), + test_needs_fk=True, + ) + + Table( + "user", + meta, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + sa.ForeignKeyConstraint( + ["tid"], ["table.id"], name="myfk", **options + ), + test_needs_fk=True, + ) meta.create_all() @@ -657,49 +706,44 @@ class ComponentReflectionTest(fixtures.TablesTest): # test 'options' is always present for a backend # that can reflect these, since alembic looks for this - opts = insp.get_foreign_keys('table')[0]['options'] + opts = insp.get_foreign_keys("table")[0]["options"] - eq_( - dict( - (k, opts[k]) - for k in opts if opts[k] - ), - {} - ) + eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) - opts = insp.get_foreign_keys('user')[0]['options'] - eq_( - dict( - (k, opts[k]) - for k in opts if opts[k] - ), - options - ) + opts = insp.get_foreign_keys("user")[0]["options"] + eq_(dict((k, opts[k]) for k in opts if opts[k]), options) def _assert_insp_indexes(self, indexes, expected_indexes): - index_names = [d['name'] for d in indexes] + index_names = [d["name"] for d in indexes] for e_index in expected_indexes: - assert e_index['name'] in index_names - index = indexes[index_names.index(e_index['name'])] + assert e_index["name"] in index_names + index = indexes[index_names.index(e_index["name"])] for key in e_index: eq_(e_index[key], index[key]) @testing.provide_metadata def _test_get_indexes(self, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) # The database may decide to create indexes for foreign keys, etc. # so there may be more indexes than expected. insp = inspect(meta.bind) - indexes = insp.get_indexes('users', schema=schema) + indexes = insp.get_indexes("users", schema=schema) expected_indexes = [ - {'unique': False, - 'column_names': ['test1', 'test2'], - 'name': 'users_t_idx'}, - {'unique': False, - 'column_names': ['user_id', 'test2', 'test1'], - 'name': 'users_all_idx'} + { + "unique": False, + "column_names": ["test1", "test2"], + "name": "users_t_idx", + }, + { + "unique": False, + "column_names": ["user_id", "test2", "test1"], + "name": "users_all_idx", + }, ] self._assert_insp_indexes(indexes, expected_indexes) @@ -721,10 +765,7 @@ class ComponentReflectionTest(fixtures.TablesTest): # reflecting an index that has "x DESC" in it as the column. # the DB may or may not give us "x", but make sure we get the index # back, it has a name, it's connected to the table. - expected_indexes = [ - {'unique': False, - 'name': ixname} - ] + expected_indexes = [{"unique": False, "name": ixname}] self._assert_insp_indexes(indexes, expected_indexes) t = Table(tname, meta, autoload_with=meta.bind) @@ -748,24 +789,30 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.unique_constraint_reflection def test_get_temp_table_unique_constraints(self): insp = inspect(self.bind) - reflected = insp.get_unique_constraints('user_tmp') + reflected = insp.get_unique_constraints("user_tmp") for refl in reflected: # Different dialects handle duplicate index and constraints # differently, so ignore this flag - refl.pop('duplicates_index', None) - eq_(reflected, [{'column_names': ['name'], 'name': 'user_tmp_uq'}]) + refl.pop("duplicates_index", None) + eq_(reflected, [{"column_names": ["name"], "name": "user_tmp_uq"}]) @testing.requires.temp_table_reflection def test_get_temp_table_indexes(self): insp = inspect(self.bind) - indexes = insp.get_indexes('user_tmp') + indexes = insp.get_indexes("user_tmp") for ind in indexes: - ind.pop('dialect_options', None) + ind.pop("dialect_options", None) eq_( # TODO: we need to add better filtering for indexes/uq constraints # that are doubled up - [idx for idx in indexes if idx['name'] == 'user_tmp_ix'], - [{'unique': False, 'column_names': ['foo'], 'name': 'user_tmp_ix'}] + [idx for idx in indexes if idx["name"] == "user_tmp_ix"], + [ + { + "unique": False, + "column_names": ["foo"], + "name": "user_tmp_ix", + } + ], ) @testing.requires.unique_constraint_reflection @@ -783,36 +830,37 @@ class ComponentReflectionTest(fixtures.TablesTest): # CREATE TABLE? uniques = sorted( [ - {'name': 'unique_a', 'column_names': ['a']}, - {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']}, - {'name': 'unique_c_a_b', 'column_names': ['c', 'a', 'b']}, - {'name': 'unique_asc_key', 'column_names': ['asc', 'key']}, - {'name': 'i.have.dots', 'column_names': ['b']}, - {'name': 'i have spaces', 'column_names': ['c']}, + {"name": "unique_a", "column_names": ["a"]}, + {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]}, + {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]}, + {"name": "unique_asc_key", "column_names": ["asc", "key"]}, + {"name": "i.have.dots", "column_names": ["b"]}, + {"name": "i have spaces", "column_names": ["c"]}, ], - key=operator.itemgetter('name') + key=operator.itemgetter("name"), ) orig_meta = self.metadata table = Table( - 'testtbl', orig_meta, - Column('a', sa.String(20)), - Column('b', sa.String(30)), - Column('c', sa.Integer), + "testtbl", + orig_meta, + Column("a", sa.String(20)), + Column("b", sa.String(30)), + Column("c", sa.Integer), # reserved identifiers - Column('asc', sa.String(30)), - Column('key', sa.String(30)), - schema=schema + Column("asc", sa.String(30)), + Column("key", sa.String(30)), + schema=schema, ) for uc in uniques: table.append_constraint( - sa.UniqueConstraint(*uc['column_names'], name=uc['name']) + sa.UniqueConstraint(*uc["column_names"], name=uc["name"]) ) orig_meta.create_all() inspector = inspect(orig_meta.bind) reflected = sorted( - inspector.get_unique_constraints('testtbl', schema=schema), - key=operator.itemgetter('name') + inspector.get_unique_constraints("testtbl", schema=schema), + key=operator.itemgetter("name"), ) names_that_duplicate_index = set() @@ -820,25 +868,31 @@ class ComponentReflectionTest(fixtures.TablesTest): for orig, refl in zip(uniques, reflected): # Different dialects handle duplicate index and constraints # differently, so ignore this flag - dupe = refl.pop('duplicates_index', None) + dupe = refl.pop("duplicates_index", None) if dupe: names_that_duplicate_index.add(dupe) eq_(orig, refl) reflected_metadata = MetaData() reflected = Table( - 'testtbl', reflected_metadata, autoload_with=orig_meta.bind, - schema=schema) + "testtbl", + reflected_metadata, + autoload_with=orig_meta.bind, + schema=schema, + ) # test "deduplicates for index" logic. MySQL and Oracle # "unique constraints" are actually unique indexes (with possible # exception of a unique that is a dupe of another one in the case # of Oracle). make sure # they aren't duplicated. idx_names = set([idx.name for idx in reflected.indexes]) - uq_names = set([ - uq.name for uq in reflected.constraints - if isinstance(uq, sa.UniqueConstraint)]).difference( - ['unique_c_a_b']) + uq_names = set( + [ + uq.name + for uq in reflected.constraints + if isinstance(uq, sa.UniqueConstraint) + ] + ).difference(["unique_c_a_b"]) assert not idx_names.intersection(uq_names) if names_that_duplicate_index: @@ -858,47 +912,52 @@ class ComponentReflectionTest(fixtures.TablesTest): def _test_get_check_constraints(self, schema=None): orig_meta = self.metadata Table( - 'sa_cc', orig_meta, - Column('a', Integer()), - sa.CheckConstraint('a > 1 AND a < 5', name='cc1'), - sa.CheckConstraint('a = 1 OR (a > 2 AND a < 5)', name='cc2'), - schema=schema + "sa_cc", + orig_meta, + Column("a", Integer()), + sa.CheckConstraint("a > 1 AND a < 5", name="cc1"), + sa.CheckConstraint("a = 1 OR (a > 2 AND a < 5)", name="cc2"), + schema=schema, ) orig_meta.create_all() inspector = inspect(orig_meta.bind) reflected = sorted( - inspector.get_check_constraints('sa_cc', schema=schema), - key=operator.itemgetter('name') + inspector.get_check_constraints("sa_cc", schema=schema), + key=operator.itemgetter("name"), ) # trying to minimize effect of quoting, parenthesis, etc. # may need to add more to this as new dialects get CHECK # constraint reflection support def normalize(sqltext): - return " ".join(re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I)) + return " ".join( + re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I) + ) reflected = [ - {"name": item["name"], - "sqltext": normalize(item["sqltext"])} + {"name": item["name"], "sqltext": normalize(item["sqltext"])} for item in reflected ] eq_( reflected, [ - {'name': 'cc1', 'sqltext': 'a > 1 and a < 5'}, - {'name': 'cc2', 'sqltext': 'a = 1 or a > 2 and a < 5'} - ] + {"name": "cc1", "sqltext": "a > 1 and a < 5"}, + {"name": "cc2", "sqltext": "a = 1 or a > 2 and a < 5"}, + ], ) @testing.provide_metadata def _test_get_view_definition(self, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings - view_name1 = 'users_v' - view_name2 = 'email_addresses_v' + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) + view_name1 = "users_v" + view_name2 = "email_addresses_v" insp = inspect(meta.bind) v1 = insp.get_view_definition(view_name1, schema=schema) self.assert_(v1) @@ -918,18 +977,21 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.provide_metadata def _test_get_table_oid(self, table_name, schema=None): meta = self.metadata - users, addresses, dingalings = self.tables.users, \ - self.tables.email_addresses, self.tables.dingalings + users, addresses, dingalings = ( + self.tables.users, + self.tables.email_addresses, + self.tables.dingalings, + ) insp = inspect(meta.bind) oid = insp.get_table_oid(table_name, schema) self.assert_(isinstance(oid, int)) def test_get_table_oid(self): - self._test_get_table_oid('users') + self._test_get_table_oid("users") @testing.requires.schemas def test_get_table_oid_with_schema(self): - self._test_get_table_oid('users', schema=testing.config.test_schema) + self._test_get_table_oid("users", schema=testing.config.test_schema) @testing.requires.table_reflection @testing.provide_metadata @@ -950,49 +1012,53 @@ class ComponentReflectionTest(fixtures.TablesTest): insp = inspect(meta.bind) for tname, cname in [ - ('users', 'user_id'), - ('email_addresses', 'address_id'), - ('dingalings', 'dingaling_id'), + ("users", "user_id"), + ("email_addresses", "address_id"), + ("dingalings", "dingaling_id"), ]: cols = insp.get_columns(tname) - id_ = {c['name']: c for c in cols}[cname] - assert id_.get('autoincrement', True) + id_ = {c["name"]: c for c in cols}[cname] + assert id_.get("autoincrement", True) class NormalizedNameTest(fixtures.TablesTest): - __requires__ = 'denormalized_names', + __requires__ = ("denormalized_names",) __backend__ = True @classmethod def define_tables(cls, metadata): Table( - quoted_name('t1', quote=True), metadata, - Column('id', Integer, primary_key=True), + quoted_name("t1", quote=True), + metadata, + Column("id", Integer, primary_key=True), ) Table( - quoted_name('t2', quote=True), metadata, - Column('id', Integer, primary_key=True), - Column('t1id', ForeignKey('t1.id')) + quoted_name("t2", quote=True), + metadata, + Column("id", Integer, primary_key=True), + Column("t1id", ForeignKey("t1.id")), ) def test_reflect_lowercase_forced_tables(self): m2 = MetaData(testing.db) - t2_ref = Table(quoted_name('t2', quote=True), m2, autoload=True) - t1_ref = m2.tables['t1'] + t2_ref = Table(quoted_name("t2", quote=True), m2, autoload=True) + t1_ref = m2.tables["t1"] assert t2_ref.c.t1id.references(t1_ref.c.id) m3 = MetaData(testing.db) - m3.reflect(only=lambda name, m: name.lower() in ('t1', 't2')) - assert m3.tables['t2'].c.t1id.references(m3.tables['t1'].c.id) + m3.reflect(only=lambda name, m: name.lower() in ("t1", "t2")) + assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id) def test_get_table_names(self): tablenames = [ - t for t in inspect(testing.db).get_table_names() - if t.lower() in ("t1", "t2")] + t + for t in inspect(testing.db).get_table_names() + if t.lower() in ("t1", "t2") + ] eq_(tablenames[0].upper(), tablenames[0].lower()) eq_(tablenames[1].upper(), tablenames[1].lower()) -__all__ = ('ComponentReflectionTest', 'HasTableTest', 'NormalizedNameTest') +__all__ = ("ComponentReflectionTest", "HasTableTest", "NormalizedNameTest") diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index f464d47eb..247f05cf5 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -15,14 +15,18 @@ class RowFetchTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table('plain_pk', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) - Table('has_dates', metadata, - Column('id', Integer, primary_key=True), - Column('today', DateTime) - ) + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + Table( + "has_dates", + metadata, + Column("id", Integer, primary_key=True), + Column("today", DateTime), + ) @classmethod def insert_data(cls): @@ -32,65 +36,51 @@ class RowFetchTest(fixtures.TablesTest): {"id": 1, "data": "d1"}, {"id": 2, "data": "d2"}, {"id": 3, "data": "d3"}, - ] + ], ) config.db.execute( cls.tables.has_dates.insert(), - [ - {"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)} - ] + [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}], ) def test_via_string(self): row = config.db.execute( - self.tables.plain_pk.select(). - order_by(self.tables.plain_pk.c.id) + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) ).first() - eq_( - row['id'], 1 - ) - eq_( - row['data'], "d1" - ) + eq_(row["id"], 1) + eq_(row["data"], "d1") def test_via_int(self): row = config.db.execute( - self.tables.plain_pk.select(). - order_by(self.tables.plain_pk.c.id) + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) ).first() - eq_( - row[0], 1 - ) - eq_( - row[1], "d1" - ) + eq_(row[0], 1) + eq_(row[1], "d1") def test_via_col_object(self): row = config.db.execute( - self.tables.plain_pk.select(). - order_by(self.tables.plain_pk.c.id) + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) ).first() - eq_( - row[self.tables.plain_pk.c.id], 1 - ) - eq_( - row[self.tables.plain_pk.c.data], "d1" - ) + eq_(row[self.tables.plain_pk.c.id], 1) + eq_(row[self.tables.plain_pk.c.data], "d1") @requirements.duplicate_names_in_cursor_description def test_row_with_dupe_names(self): result = config.db.execute( - select([self.tables.plain_pk.c.data, - self.tables.plain_pk.c.data.label('data')]). - order_by(self.tables.plain_pk.c.id) + select( + [ + self.tables.plain_pk.c.data, + self.tables.plain_pk.c.data.label("data"), + ] + ).order_by(self.tables.plain_pk.c.id) ) row = result.first() - eq_(result.keys(), ['data', 'data']) - eq_(row, ('d1', 'd1')) + eq_(result.keys(), ["data", "data"]) + eq_(row, ("d1", "d1")) def test_row_w_scalar_select(self): """test that a scalar select as a column is returned as such @@ -101,11 +91,11 @@ class RowFetchTest(fixtures.TablesTest): """ datetable = self.tables.has_dates - s = select([datetable.alias('x').c.today]).as_scalar() - s2 = select([datetable.c.id, s.label('somelabel')]) + s = select([datetable.alias("x").c.today]).as_scalar() + s2 = select([datetable.c.id, s.label("somelabel")]) row = config.db.execute(s2).first() - eq_(row['somelabel'], datetime.datetime(2006, 5, 12, 12, 0, 0)) + eq_(row["somelabel"], datetime.datetime(2006, 5, 12, 12, 0, 0)) class PercentSchemaNamesTest(fixtures.TablesTest): @@ -117,29 +107,31 @@ class PercentSchemaNamesTest(fixtures.TablesTest): """ - __requires__ = ('percent_schema_names', ) + __requires__ = ("percent_schema_names",) __backend__ = True @classmethod def define_tables(cls, metadata): - cls.tables.percent_table = Table('percent%table', metadata, - Column("percent%", Integer), - Column( - "spaces % more spaces", Integer), - ) + cls.tables.percent_table = Table( + "percent%table", + metadata, + Column("percent%", Integer), + Column("spaces % more spaces", Integer), + ) cls.tables.lightweight_percent_table = sql.table( - 'percent%table', sql.column("percent%"), - sql.column("spaces % more spaces") + "percent%table", + sql.column("percent%"), + sql.column("spaces % more spaces"), ) def test_single_roundtrip(self): percent_table = self.tables.percent_table for params in [ - {'percent%': 5, 'spaces % more spaces': 12}, - {'percent%': 7, 'spaces % more spaces': 11}, - {'percent%': 9, 'spaces % more spaces': 10}, - {'percent%': 11, 'spaces % more spaces': 9} + {"percent%": 5, "spaces % more spaces": 12}, + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, ]: config.db.execute(percent_table.insert(), params) self._assert_table() @@ -147,14 +139,15 @@ class PercentSchemaNamesTest(fixtures.TablesTest): def test_executemany_roundtrip(self): percent_table = self.tables.percent_table config.db.execute( - percent_table.insert(), - {'percent%': 5, 'spaces % more spaces': 12} + percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12} ) config.db.execute( percent_table.insert(), - [{'percent%': 7, 'spaces % more spaces': 11}, - {'percent%': 9, 'spaces % more spaces': 10}, - {'percent%': 11, 'spaces % more spaces': 9}] + [ + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ], ) self._assert_table() @@ -163,85 +156,81 @@ class PercentSchemaNamesTest(fixtures.TablesTest): lightweight_percent_table = self.tables.lightweight_percent_table for table in ( - percent_table, - percent_table.alias(), - lightweight_percent_table, - lightweight_percent_table.alias()): + percent_table, + percent_table.alias(), + lightweight_percent_table, + lightweight_percent_table.alias(), + ): eq_( list( config.db.execute( - table.select().order_by(table.c['percent%']) + table.select().order_by(table.c["percent%"]) ) ), - [ - (5, 12), - (7, 11), - (9, 10), - (11, 9) - ] + [(5, 12), (7, 11), (9, 10), (11, 9)], ) eq_( list( config.db.execute( - table.select(). - where(table.c['spaces % more spaces'].in_([9, 10])). - order_by(table.c['percent%']), + table.select() + .where(table.c["spaces % more spaces"].in_([9, 10])) + .order_by(table.c["percent%"]) ) ), - [ - (9, 10), - (11, 9) - ] + [(9, 10), (11, 9)], ) - row = config.db.execute(table.select(). - order_by(table.c['percent%'])).first() - eq_(row['percent%'], 5) - eq_(row['spaces % more spaces'], 12) + row = config.db.execute( + table.select().order_by(table.c["percent%"]) + ).first() + eq_(row["percent%"], 5) + eq_(row["spaces % more spaces"], 12) - eq_(row[table.c['percent%']], 5) - eq_(row[table.c['spaces % more spaces']], 12) + eq_(row[table.c["percent%"]], 5) + eq_(row[table.c["spaces % more spaces"]], 12) config.db.execute( percent_table.update().values( - {percent_table.c['spaces % more spaces']: 15} + {percent_table.c["spaces % more spaces"]: 15} ) ) eq_( list( config.db.execute( - percent_table. - select(). - order_by(percent_table.c['percent%']) + percent_table.select().order_by( + percent_table.c["percent%"] + ) ) ), - [(5, 15), (7, 15), (9, 15), (11, 15)] + [(5, 15), (7, 15), (9, 15), (11, 15)], ) -class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): +class ServerSideCursorsTest( + fixtures.TestBase, testing.AssertsExecutionResults +): - __requires__ = ('server_side_cursors', ) + __requires__ = ("server_side_cursors",) __backend__ = True def _is_server_side(self, cursor): if self.engine.dialect.driver == "psycopg2": return cursor.name - elif self.engine.dialect.driver == 'pymysql': - sscursor = __import__('pymysql.cursors').cursors.SSCursor + elif self.engine.dialect.driver == "pymysql": + sscursor = __import__("pymysql.cursors").cursors.SSCursor return isinstance(cursor, sscursor) elif self.engine.dialect.driver == "mysqldb": - sscursor = __import__('MySQLdb.cursors').cursors.SSCursor + sscursor = __import__("MySQLdb.cursors").cursors.SSCursor return isinstance(cursor, sscursor) else: return False def _fixture(self, server_side_cursors): self.engine = engines.testing_engine( - options={'server_side_cursors': server_side_cursors} + options={"server_side_cursors": server_side_cursors} ) return self.engine @@ -251,12 +240,12 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): def test_global_string(self): engine = self._fixture(True) - result = engine.execute('select 1') + result = engine.execute("select 1") assert self._is_server_side(result.cursor) def test_global_text(self): engine = self._fixture(True) - result = engine.execute(text('select 1')) + result = engine.execute(text("select 1")) assert self._is_server_side(result.cursor) def test_global_expr(self): @@ -266,7 +255,7 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): def test_global_off_explicit(self): engine = self._fixture(False) - result = engine.execute(text('select 1')) + result = engine.execute(text("select 1")) # It should be off globally ... @@ -286,10 +275,11 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): engine = self._fixture(False) # and this one - result = \ - engine.connect().execution_options(stream_results=True).\ - execute('select 1' - ) + result = ( + engine.connect() + .execution_options(stream_results=True) + .execute("select 1") + ) assert self._is_server_side(result.cursor) def test_stmt_enabled_conn_option_disabled(self): @@ -298,9 +288,9 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): s = select([1]).execution_options(stream_results=True) # not this one - result = \ - engine.connect().execution_options(stream_results=False).\ - execute(s) + result = ( + engine.connect().execution_options(stream_results=False).execute(s) + ) assert not self._is_server_side(result.cursor) def test_stmt_option_disabled(self): @@ -329,18 +319,18 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): def test_for_update_string(self): engine = self._fixture(True) - result = engine.execute('SELECT 1 FOR UPDATE') + result = engine.execute("SELECT 1 FOR UPDATE") assert self._is_server_side(result.cursor) def test_text_no_ss(self): engine = self._fixture(False) - s = text('select 42') + s = text("select 42") result = engine.execute(s) assert not self._is_server_side(result.cursor) def test_text_ss_option(self): engine = self._fixture(False) - s = text('select 42').execution_options(stream_results=True) + s = text("select 42").execution_options(stream_results=True) result = engine.execute(s) assert self._is_server_side(result.cursor) @@ -349,19 +339,25 @@ class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults): md = self.metadata engine = self._fixture(True) - test_table = Table('test_table', md, - Column('id', Integer, primary_key=True), - Column('data', String(50))) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) test_table.create(checkfirst=True) - test_table.insert().execute(data='data1') - test_table.insert().execute(data='data2') - eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(), - [(1, 'data1'), (2, 'data2')]) - test_table.update().where( - test_table.c.id == 2).values( - data=test_table.c.data + - ' updated').execute() - eq_(test_table.select().order_by(test_table.c.id).execute().fetchall(), - [(1, 'data1'), (2, 'data2 updated')]) + test_table.insert().execute(data="data1") + test_table.insert().execute(data="data2") + eq_( + test_table.select().order_by(test_table.c.id).execute().fetchall(), + [(1, "data1"), (2, "data2")], + ) + test_table.update().where(test_table.c.id == 2).values( + data=test_table.c.data + " updated" + ).execute() + eq_( + test_table.select().order_by(test_table.c.id).execute().fetchall(), + [(1, "data1"), (2, "data2 updated")], + ) test_table.delete().execute() - eq_(select([func.count('*')]).select_from(test_table).scalar(), 0) + eq_(select([func.count("*")]).select_from(test_table).scalar(), 0) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 73ce02492..032b68eb6 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -16,10 +16,12 @@ class CollateTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(100)) - ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(100)), + ) @classmethod def insert_data(cls): @@ -28,26 +30,21 @@ class CollateTest(fixtures.TablesTest): [ {"id": 1, "data": "collate data1"}, {"id": 2, "data": "collate data2"}, - ] + ], ) def _assert_result(self, select, result): - eq_( - config.db.execute(select).fetchall(), - result - ) + eq_(config.db.execute(select).fetchall(), result) @testing.requires.order_by_collation def test_collate_order_by(self): collation = testing.requires.get_order_by_collation(testing.config) self._assert_result( - select([self.tables.some_table]). - order_by(self.tables.some_table.c.data.collate(collation).asc()), - [ - (1, "collate data1"), - (2, "collate data2"), - ] + select([self.tables.some_table]).order_by( + self.tables.some_table.c.data.collate(collation).asc() + ), + [(1, "collate data1"), (2, "collate data2")], ) @@ -59,17 +56,20 @@ class OrderByLabelTest(fixtures.TablesTest): setting. """ + __backend__ = True @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer), - Column('q', String(50)), - Column('p', String(50)) - ) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("q", String(50)), + Column("p", String(50)), + ) @classmethod def insert_data(cls): @@ -79,65 +79,55 @@ class OrderByLabelTest(fixtures.TablesTest): {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"}, {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"}, {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"}, - ] + ], ) def _assert_result(self, select, result): - eq_( - config.db.execute(select).fetchall(), - result - ) + eq_(config.db.execute(select).fetchall(), result) def test_plain(self): table = self.tables.some_table - lx = table.c.x.label('lx') - self._assert_result( - select([lx]).order_by(lx), - [(1, ), (2, ), (3, )] - ) + lx = table.c.x.label("lx") + self._assert_result(select([lx]).order_by(lx), [(1,), (2,), (3,)]) def test_composed_int(self): table = self.tables.some_table - lx = (table.c.x + table.c.y).label('lx') - self._assert_result( - select([lx]).order_by(lx), - [(3, ), (5, ), (7, )] - ) + lx = (table.c.x + table.c.y).label("lx") + self._assert_result(select([lx]).order_by(lx), [(3,), (5,), (7,)]) def test_composed_multiple(self): table = self.tables.some_table - lx = (table.c.x + table.c.y).label('lx') - ly = (func.lower(table.c.q) + table.c.p).label('ly') + lx = (table.c.x + table.c.y).label("lx") + ly = (func.lower(table.c.q) + table.c.p).label("ly") self._assert_result( select([lx, ly]).order_by(lx, ly.desc()), - [(3, util.u('q1p3')), (5, util.u('q2p2')), (7, util.u('q3p1'))] + [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))], ) def test_plain_desc(self): table = self.tables.some_table - lx = table.c.x.label('lx') + lx = table.c.x.label("lx") self._assert_result( - select([lx]).order_by(lx.desc()), - [(3, ), (2, ), (1, )] + select([lx]).order_by(lx.desc()), [(3,), (2,), (1,)] ) def test_composed_int_desc(self): table = self.tables.some_table - lx = (table.c.x + table.c.y).label('lx') + lx = (table.c.x + table.c.y).label("lx") self._assert_result( - select([lx]).order_by(lx.desc()), - [(7, ), (5, ), (3, )] + select([lx]).order_by(lx.desc()), [(7,), (5,), (3,)] ) @testing.requires.group_by_complex_expression def test_group_by_composed(self): table = self.tables.some_table - expr = (table.c.x + table.c.y).label('lx') - stmt = select([func.count(table.c.id), expr]).group_by(expr).order_by(expr) - self._assert_result( - stmt, - [(1, 3), (1, 5), (1, 7)] + expr = (table.c.x + table.c.y).label("lx") + stmt = ( + select([func.count(table.c.id), expr]) + .group_by(expr) + .order_by(expr) ) + self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)]) class LimitOffsetTest(fixtures.TablesTest): @@ -145,10 +135,13 @@ class LimitOffsetTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer)) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) @classmethod def insert_data(cls): @@ -159,20 +152,17 @@ class LimitOffsetTest(fixtures.TablesTest): {"id": 2, "x": 2, "y": 3}, {"id": 3, "x": 3, "y": 4}, {"id": 4, "x": 4, "y": 5}, - ] + ], ) def _assert_result(self, select, result, params=()): - eq_( - config.db.execute(select, params).fetchall(), - result - ) + eq_(config.db.execute(select, params).fetchall(), result) def test_simple_limit(self): table = self.tables.some_table self._assert_result( select([table]).order_by(table.c.id).limit(2), - [(1, 1, 2), (2, 2, 3)] + [(1, 1, 2), (2, 2, 3)], ) @testing.requires.offset @@ -180,7 +170,7 @@ class LimitOffsetTest(fixtures.TablesTest): table = self.tables.some_table self._assert_result( select([table]).order_by(table.c.id).offset(2), - [(3, 3, 4), (4, 4, 5)] + [(3, 3, 4), (4, 4, 5)], ) @testing.requires.offset @@ -188,7 +178,7 @@ class LimitOffsetTest(fixtures.TablesTest): table = self.tables.some_table self._assert_result( select([table]).order_by(table.c.id).limit(2).offset(1), - [(2, 2, 3), (3, 3, 4)] + [(2, 2, 3), (3, 3, 4)], ) @testing.requires.offset @@ -198,41 +188,40 @@ class LimitOffsetTest(fixtures.TablesTest): table = self.tables.some_table stmt = select([table]).order_by(table.c.id).limit(2).offset(1) sql = stmt.compile( - dialect=config.db.dialect, - compile_kwargs={"literal_binds": True}) + dialect=config.db.dialect, compile_kwargs={"literal_binds": True} + ) sql = str(sql) - self._assert_result( - sql, - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(sql, [(2, 2, 3), (3, 3, 4)]) @testing.requires.bound_limit_offset def test_bound_limit(self): table = self.tables.some_table self._assert_result( - select([table]).order_by(table.c.id).limit(bindparam('l')), + select([table]).order_by(table.c.id).limit(bindparam("l")), [(1, 1, 2), (2, 2, 3)], - params={"l": 2} + params={"l": 2}, ) @testing.requires.bound_limit_offset def test_bound_offset(self): table = self.tables.some_table self._assert_result( - select([table]).order_by(table.c.id).offset(bindparam('o')), + select([table]).order_by(table.c.id).offset(bindparam("o")), [(3, 3, 4), (4, 4, 5)], - params={"o": 2} + params={"o": 2}, ) @testing.requires.bound_limit_offset def test_bound_limit_offset(self): table = self.tables.some_table self._assert_result( - select([table]).order_by(table.c.id). - limit(bindparam("l")).offset(bindparam("o")), + select([table]) + .order_by(table.c.id) + .limit(bindparam("l")) + .offset(bindparam("o")), [(2, 2, 3), (3, 3, 4)], - params={"l": 2, "o": 1} + params={"l": 2, "o": 1}, ) @@ -241,10 +230,13 @@ class CompoundSelectTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer)) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) @classmethod def insert_data(cls): @@ -255,14 +247,11 @@ class CompoundSelectTest(fixtures.TablesTest): {"id": 2, "x": 2, "y": 3}, {"id": 3, "x": 3, "y": 4}, {"id": 4, "x": 4, "y": 5}, - ] + ], ) def _assert_result(self, select, result, params=()): - eq_( - config.db.execute(select, params).fetchall(), - result - ) + eq_(config.db.execute(select, params).fetchall(), result) def test_plain_union(self): table = self.tables.some_table @@ -270,10 +259,7 @@ class CompoundSelectTest(fixtures.TablesTest): s2 = select([table]).where(table.c.id == 3) u1 = union(s1, s2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) def test_select_from_plain_union(self): table = self.tables.some_table @@ -281,80 +267,88 @@ class CompoundSelectTest(fixtures.TablesTest): s2 = select([table]).where(table.c.id == 3) u1 = union(s1, s2).alias().select() - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) @testing.requires.order_by_col_from_union @testing.requires.parens_in_union_contained_select_w_limit_offset def test_limit_offset_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - limit(1).order_by(table.c.id) - s2 = select([table]).where(table.c.id == 3).\ - limit(1).order_by(table.c.id) + s1 = ( + select([table]) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + ) + s2 = ( + select([table]) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + ) u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) @testing.requires.parens_in_union_contained_select_wo_limit_offset def test_order_by_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - order_by(table.c.id) - s2 = select([table]).where(table.c.id == 3).\ - order_by(table.c.id) + s1 = select([table]).where(table.c.id == 2).order_by(table.c.id) + s2 = select([table]).where(table.c.id == 3).order_by(table.c.id) u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) def test_distinct_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - distinct() - s2 = select([table]).where(table.c.id == 3).\ - distinct() + s1 = select([table]).where(table.c.id == 2).distinct() + s2 = select([table]).where(table.c.id == 3).distinct() u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) @testing.requires.parens_in_union_contained_select_w_limit_offset def test_limit_offset_in_unions_from_alias(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - limit(1).order_by(table.c.id) - s2 = select([table]).where(table.c.id == 3).\ - limit(1).order_by(table.c.id) + s1 = ( + select([table]) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + ) + s2 = ( + select([table]) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + ) # this necessarily has double parens u1 = union(s1, s2).alias() self._assert_result( - u1.select().limit(2).order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] + u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] ) def test_limit_offset_aliased_selectable_in_unions(self): table = self.tables.some_table - s1 = select([table]).where(table.c.id == 2).\ - limit(1).order_by(table.c.id).alias().select() - s2 = select([table]).where(table.c.id == 3).\ - limit(1).order_by(table.c.id).alias().select() + s1 = ( + select([table]) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select([table]) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) u1 = union(s1, s2).limit(2) - self._assert_result( - u1.order_by(u1.c.id), - [(2, 2, 3), (3, 3, 4)] - ) + self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) class ExpandingBoundInTest(fixtures.TablesTest): @@ -362,11 +356,14 @@ class ExpandingBoundInTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('x', Integer), - Column('y', Integer), - Column('z', String(50))) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) @classmethod def insert_data(cls): @@ -377,178 +374,184 @@ class ExpandingBoundInTest(fixtures.TablesTest): {"id": 2, "x": 2, "y": 3, "z": "z2"}, {"id": 3, "x": 3, "y": 4, "z": "z3"}, {"id": 4, "x": 4, "y": 5, "z": "z4"}, - ] + ], ) def _assert_result(self, select, result, params=()): - eq_( - config.db.execute(select, params).fetchall(), - result - ) + eq_(config.db.execute(select, params).fetchall(), result) def test_multiple_empty_sets(self): # test that any anonymous aliasing used by the dialect # is fine with duplicates table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.in_(bindparam('q', expanding=True))).where( - table.c.y.in_(bindparam('p', expanding=True)) - ).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": [], "p": []}, + stmt = ( + select([table.c.id]) + .where(table.c.x.in_(bindparam("q", expanding=True))) + .where(table.c.y.in_(bindparam("p", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": [], "p": []}) + @testing.requires.tuple_in def test_empty_heterogeneous_tuples(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.z).in_( - bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.z).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + @testing.requires.tuple_in def test_empty_homogeneous_tuples(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.y).in_( - bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.y).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + def test_bound_in_scalar(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [(2, ), (3, ), (4, )], - params={"q": [2, 3, 4]}, + stmt = ( + select([table.c.id]) + .where(table.c.x.in_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]}) + @testing.requires.tuple_in def test_bound_in_two_tuple(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.y).in_( - bindparam('q', expanding=True))).order_by(table.c.id) + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.y).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) + ) self._assert_result( - stmt, - [(2, ), (3, ), (4, )], - params={"q": [(2, 3), (3, 4), (4, 5)]}, + stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]} ) @testing.requires.tuple_in def test_bound_in_heterogeneous_two_tuple(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - tuple_(table.c.x, table.c.z).in_( - bindparam('q', expanding=True))).order_by(table.c.id) + stmt = ( + select([table.c.id]) + .where( + tuple_(table.c.x, table.c.z).in_( + bindparam("q", expanding=True) + ) + ) + .order_by(table.c.id) + ) self._assert_result( stmt, - [(2, ), (3, ), (4, )], + [(2,), (3,), (4,)], params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, ) def test_empty_set_against_integer(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.in_(bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.x.in_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + def test_empty_set_against_integer_negation(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.x.notin_(bindparam('q', expanding=True)) - ).order_by(table.c.id) - - self._assert_result( - stmt, - [(1, ), (2, ), (3, ), (4, )], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.x.notin_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + def test_empty_set_against_string(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.z.in_(bindparam('q', expanding=True))).order_by(table.c.id) - - self._assert_result( - stmt, - [], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.z.in_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [], params={"q": []}) + def test_empty_set_against_string_negation(self): table = self.tables.some_table - stmt = select([table.c.id]).where( - table.c.z.notin_(bindparam('q', expanding=True)) - ).order_by(table.c.id) - - self._assert_result( - stmt, - [(1, ), (2, ), (3, ), (4, )], - params={"q": []}, + stmt = ( + select([table.c.id]) + .where(table.c.z.notin_(bindparam("q", expanding=True))) + .order_by(table.c.id) ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + def test_null_in_empty_set_is_false(self): - stmt = select([ - case( - [ - ( - null().in_(bindparam('foo', value=(), expanding=True)), - true() - ) - ], - else_=false() - ) - ]) - in_( - config.db.execute(stmt).fetchone()[0], - (False, 0) + stmt = select( + [ + case( + [ + ( + null().in_( + bindparam("foo", value=(), expanding=True) + ), + true(), + ) + ], + else_=false(), + ) + ] ) + in_(config.db.execute(stmt).fetchone()[0], (False, 0)) class LikeFunctionsTest(fixtures.TablesTest): __backend__ = True - run_inserts = 'once' + run_inserts = "once" run_deletes = None @classmethod def define_tables(cls, metadata): - Table("some_table", metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50))) + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) @classmethod def insert_data(cls): @@ -565,7 +568,7 @@ class LikeFunctionsTest(fixtures.TablesTest): {"id": 8, "data": "ab9cdefg"}, {"id": 9, "data": "abcde#fg"}, {"id": 10, "data": "abcd9fg"}, - ] + ], ) def _test(self, expr, expected): @@ -573,8 +576,10 @@ class LikeFunctionsTest(fixtures.TablesTest): with config.db.connect() as conn: rows = { - value for value, in - conn.execute(select([some_table.c.id]).where(expr)) + value + for value, in conn.execute( + select([some_table.c.id]).where(expr) + ) } eq_(rows, expected) @@ -591,7 +596,8 @@ class LikeFunctionsTest(fixtures.TablesTest): col = self.tables.some_table.c.data self._test( col.startswith(literal_column("'ab%c'")), - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + ) def test_startswith_escape(self): col = self.tables.some_table.c.data @@ -608,8 +614,9 @@ class LikeFunctionsTest(fixtures.TablesTest): def test_endswith_sqlexpr(self): col = self.tables.some_table.c.data - self._test(col.endswith(literal_column("'e%fg'")), - {1, 2, 3, 4, 5, 6, 7, 8, 9}) + self._test( + col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9} + ) def test_endswith_autoescape(self): col = self.tables.some_table.c.data @@ -640,5 +647,3 @@ class LikeFunctionsTest(fixtures.TablesTest): col = self.tables.some_table.c.data self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) self._test(col.contains("b#cd", autoescape=True, escape="#"), {7}) - - diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index f1c00de6b..15a850fe9 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -9,140 +9,144 @@ from ..schema import Table, Column class SequenceTest(fixtures.TablesTest): - __requires__ = ('sequences',) + __requires__ = ("sequences",) __backend__ = True - run_create_tables = 'each' + run_create_tables = "each" @classmethod def define_tables(cls, metadata): - Table('seq_pk', metadata, - Column('id', Integer, Sequence('tab_id_seq'), primary_key=True), - Column('data', String(50)) - ) + Table( + "seq_pk", + metadata, + Column("id", Integer, Sequence("tab_id_seq"), primary_key=True), + Column("data", String(50)), + ) - Table('seq_opt_pk', metadata, - Column('id', Integer, Sequence('tab_id_seq', optional=True), - primary_key=True), - Column('data', String(50)) - ) + Table( + "seq_opt_pk", + metadata, + Column( + "id", + Integer, + Sequence("tab_id_seq", optional=True), + primary_key=True, + ), + Column("data", String(50)), + ) def test_insert_roundtrip(self): - config.db.execute( - self.tables.seq_pk.insert(), - data="some data" - ) + config.db.execute(self.tables.seq_pk.insert(), data="some data") self._assert_round_trip(self.tables.seq_pk, config.db) def test_insert_lastrowid(self): - r = config.db.execute( - self.tables.seq_pk.insert(), - data="some data" - ) - eq_( - r.inserted_primary_key, - [1] - ) + r = config.db.execute(self.tables.seq_pk.insert(), data="some data") + eq_(r.inserted_primary_key, [1]) def test_nextval_direct(self): - r = config.db.execute( - self.tables.seq_pk.c.id.default - ) - eq_( - r, 1 - ) + r = config.db.execute(self.tables.seq_pk.c.id.default) + eq_(r, 1) @requirements.sequences_optional def test_optional_seq(self): r = config.db.execute( - self.tables.seq_opt_pk.insert(), - data="some data" - ) - eq_( - r.inserted_primary_key, - [1] + self.tables.seq_opt_pk.insert(), data="some data" ) + eq_(r.inserted_primary_key, [1]) def _assert_round_trip(self, table, conn): row = conn.execute(table.select()).first() - eq_( - row, - (1, "some data") - ) + eq_(row, (1, "some data")) class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase): - __requires__ = ('sequences',) + __requires__ = ("sequences",) __backend__ = True def test_literal_binds_inline_compile(self): table = Table( - 'x', MetaData(), - Column('y', Integer, Sequence('y_seq')), - Column('q', Integer)) + "x", + MetaData(), + Column("y", Integer, Sequence("y_seq")), + Column("q", Integer), + ) stmt = table.insert().values(q=5) seq_nextval = testing.db.dialect.statement_compiler( - statement=None, dialect=testing.db.dialect).visit_sequence( - Sequence("y_seq")) + statement=None, dialect=testing.db.dialect + ).visit_sequence(Sequence("y_seq")) self.assert_compile( stmt, - "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval, ), + "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,), literal_binds=True, - dialect=testing.db.dialect) + dialect=testing.db.dialect, + ) class HasSequenceTest(fixtures.TestBase): - __requires__ = 'sequences', + __requires__ = ("sequences",) __backend__ = True def test_has_sequence(self): - s1 = Sequence('user_id_seq') + s1 = Sequence("user_id_seq") testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence(testing.db, - 'user_id_seq'), True) + eq_( + testing.db.dialect.has_sequence(testing.db, "user_id_seq"), + True, + ) finally: testing.db.execute(schema.DropSequence(s1)) @testing.requires.schemas def test_has_sequence_schema(self): - s1 = Sequence('user_id_seq', schema=config.test_schema) + s1 = Sequence("user_id_seq", schema=config.test_schema) testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence( - testing.db, 'user_id_seq', schema=config.test_schema), True) + eq_( + testing.db.dialect.has_sequence( + testing.db, "user_id_seq", schema=config.test_schema + ), + True, + ) finally: testing.db.execute(schema.DropSequence(s1)) def test_has_sequence_neg(self): - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), - False) + eq_(testing.db.dialect.has_sequence(testing.db, "user_id_seq"), False) @testing.requires.schemas def test_has_sequence_schemas_neg(self): - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', - schema=config.test_schema), - False) + eq_( + testing.db.dialect.has_sequence( + testing.db, "user_id_seq", schema=config.test_schema + ), + False, + ) @testing.requires.schemas def test_has_sequence_default_not_in_remote(self): - s1 = Sequence('user_id_seq') + s1 = Sequence("user_id_seq") testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq', - schema=config.test_schema), - False) + eq_( + testing.db.dialect.has_sequence( + testing.db, "user_id_seq", schema=config.test_schema + ), + False, + ) finally: testing.db.execute(schema.DropSequence(s1)) @testing.requires.schemas def test_has_sequence_remote_not_in_default(self): - s1 = Sequence('user_id_seq', schema=config.test_schema) + s1 = Sequence("user_id_seq", schema=config.test_schema) testing.db.execute(schema.CreateSequence(s1)) try: - eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), - False) + eq_( + testing.db.dialect.has_sequence(testing.db, "user_id_seq"), + False, + ) finally: testing.db.execute(schema.DropSequence(s1)) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 27c7bb115..6dfb80915 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -4,9 +4,24 @@ from .. import fixtures, config from ..assertions import eq_ from ..config import requirements from sqlalchemy import Integer, Unicode, UnicodeText, select, TIMESTAMP -from sqlalchemy import Date, DateTime, Time, MetaData, String, \ - Text, Numeric, Float, literal, Boolean, cast, null, JSON, and_, \ - type_coerce, BigInteger +from sqlalchemy import ( + Date, + DateTime, + Time, + MetaData, + String, + Text, + Numeric, + Float, + literal, + Boolean, + cast, + null, + JSON, + and_, + type_coerce, + BigInteger, +) from ..schema import Table, Column from ... import testing import decimal @@ -24,13 +39,17 @@ class _LiteralRoundTripFixture(object): # into a typed column. we can then SELECT it back as its # official type; ideally we'd be able to use CAST here # but MySQL in particular can't CAST fully - t = Table('t', self.metadata, Column('x', type_)) + t = Table("t", self.metadata, Column("x", type_)) t.create() for value in input_: - ins = t.insert().values(x=literal(value)).compile( - dialect=testing.db.dialect, - compile_kwargs=dict(literal_binds=True) + ins = ( + t.insert() + .values(x=literal(value)) + .compile( + dialect=testing.db.dialect, + compile_kwargs=dict(literal_binds=True), + ) ) testing.db.execute(ins) @@ -42,40 +61,33 @@ class _LiteralRoundTripFixture(object): class _UnicodeFixture(_LiteralRoundTripFixture): - __requires__ = 'unicode_data', + __requires__ = ("unicode_data",) - data = u("Alors vous imaginez ma surprise, au lever du jour, " - "quand une drôle de petite voix m’a réveillé. Elle " - "disait: « S’il vous plaît… dessine-moi un mouton! »") + data = u( + "Alors vous imaginez ma surprise, au lever du jour, " + "quand une drôle de petite voix m’a réveillé. Elle " + "disait: « S’il vous plaît… dessine-moi un mouton! »" + ) @classmethod def define_tables(cls, metadata): - Table('unicode_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('unicode_data', cls.datatype), - ) + Table( + "unicode_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("unicode_data", cls.datatype), + ) def test_round_trip(self): unicode_table = self.tables.unicode_table - config.db.execute( - unicode_table.insert(), - { - 'unicode_data': self.data, - } - ) + config.db.execute(unicode_table.insert(), {"unicode_data": self.data}) - row = config.db.execute( - select([ - unicode_table.c.unicode_data, - ]) - ).first() + row = config.db.execute(select([unicode_table.c.unicode_data])).first() - eq_( - row, - (self.data, ) - ) + eq_(row, (self.data,)) assert isinstance(row[0], util.text_type) def test_round_trip_executemany(self): @@ -83,44 +95,29 @@ class _UnicodeFixture(_LiteralRoundTripFixture): config.db.execute( unicode_table.insert(), - [ - { - 'unicode_data': self.data, - } - for i in range(3) - ] + [{"unicode_data": self.data} for i in range(3)], ) rows = config.db.execute( - select([ - unicode_table.c.unicode_data, - ]) + select([unicode_table.c.unicode_data]) ).fetchall() - eq_( - rows, - [(self.data, ) for i in range(3)] - ) + eq_(rows, [(self.data,) for i in range(3)]) for row in rows: assert isinstance(row[0], util.text_type) def _test_empty_strings(self): unicode_table = self.tables.unicode_table - config.db.execute( - unicode_table.insert(), - {"unicode_data": u('')} - ) - row = config.db.execute( - select([unicode_table.c.unicode_data]) - ).first() - eq_(row, (u(''),)) + config.db.execute(unicode_table.insert(), {"unicode_data": u("")}) + row = config.db.execute(select([unicode_table.c.unicode_data])).first() + eq_(row, (u(""),)) def test_literal(self): self._literal_round_trip(self.datatype, [self.data], [self.data]) class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): - __requires__ = 'unicode_data', + __requires__ = ("unicode_data",) __backend__ = True datatype = Unicode(255) @@ -131,7 +128,7 @@ class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): - __requires__ = 'unicode_data', 'text_type' + __requires__ = "unicode_data", "text_type" __backend__ = True datatype = UnicodeText() @@ -142,54 +139,47 @@ class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): - __requires__ = 'text_type', + __requires__ = ("text_type",) __backend__ = True @classmethod def define_tables(cls, metadata): - Table('text_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('text_data', Text), - ) + Table( + "text_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("text_data", Text), + ) def test_text_roundtrip(self): text_table = self.tables.text_table - config.db.execute( - text_table.insert(), - {"text_data": 'some text'} - ) - row = config.db.execute( - select([text_table.c.text_data]) - ).first() - eq_(row, ('some text',)) + config.db.execute(text_table.insert(), {"text_data": "some text"}) + row = config.db.execute(select([text_table.c.text_data])).first() + eq_(row, ("some text",)) def test_text_empty_strings(self): text_table = self.tables.text_table - config.db.execute( - text_table.insert(), - {"text_data": ''} - ) - row = config.db.execute( - select([text_table.c.text_data]) - ).first() - eq_(row, ('',)) + config.db.execute(text_table.insert(), {"text_data": ""}) + row = config.db.execute(select([text_table.c.text_data])).first() + eq_(row, ("",)) def test_literal(self): self._literal_round_trip(Text, ["some text"], ["some text"]) def test_literal_quoting(self): - data = '''some 'text' hey "hi there" that's text''' + data = """some 'text' hey "hi there" that's text""" self._literal_round_trip(Text, [data], [data]) def test_literal_backslashes(self): - data = r'backslash one \ backslash two \\ end' + data = r"backslash one \ backslash two \\ end" self._literal_round_trip(Text, [data], [data]) def test_literal_percentsigns(self): - data = r'percent % signs %% percent' + data = r"percent % signs %% percent" self._literal_round_trip(Text, [data], [data]) @@ -199,9 +189,7 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): @requirements.unbounded_varchar def test_nolength_string(self): metadata = MetaData() - foo = Table('foo', metadata, - Column('one', String) - ) + foo = Table("foo", metadata, Column("one", String)) foo.create(config.db) foo.drop(config.db) @@ -210,11 +198,11 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): self._literal_round_trip(String(40), ["some text"], ["some text"]) def test_literal_quoting(self): - data = '''some 'text' hey "hi there" that's text''' + data = """some 'text' hey "hi there" that's text""" self._literal_round_trip(String(40), [data], [data]) def test_literal_backslashes(self): - data = r'backslash one \ backslash two \\ end' + data = r"backslash one \ backslash two \\ end" self._literal_round_trip(String(40), [data], [data]) @@ -223,44 +211,32 @@ class _DateFixture(_LiteralRoundTripFixture): @classmethod def define_tables(cls, metadata): - Table('date_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('date_data', cls.datatype), - ) + Table( + "date_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("date_data", cls.datatype), + ) def test_round_trip(self): date_table = self.tables.date_table - config.db.execute( - date_table.insert(), - {'date_data': self.data} - ) + config.db.execute(date_table.insert(), {"date_data": self.data}) - row = config.db.execute( - select([ - date_table.c.date_data, - ]) - ).first() + row = config.db.execute(select([date_table.c.date_data])).first() compare = self.compare or self.data - eq_(row, - (compare, )) + eq_(row, (compare,)) assert isinstance(row[0], type(compare)) def test_null(self): date_table = self.tables.date_table - config.db.execute( - date_table.insert(), - {'date_data': None} - ) + config.db.execute(date_table.insert(), {"date_data": None}) - row = config.db.execute( - select([ - date_table.c.date_data, - ]) - ).first() + row = config.db.execute(select([date_table.c.date_data])).first() eq_(row, (None,)) @testing.requires.datetime_literals @@ -270,48 +246,49 @@ class _DateFixture(_LiteralRoundTripFixture): class DateTimeTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'datetime', + __requires__ = ("datetime",) __backend__ = True datatype = DateTime data = datetime.datetime(2012, 10, 15, 12, 57, 18) class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'datetime_microseconds', + __requires__ = ("datetime_microseconds",) __backend__ = True datatype = DateTime data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'timestamp_microseconds', + __requires__ = ("timestamp_microseconds",) __backend__ = True datatype = TIMESTAMP data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) class TimeTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'time', + __requires__ = ("time",) __backend__ = True datatype = Time data = datetime.time(12, 57, 18) class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'time_microseconds', + __requires__ = ("time_microseconds",) __backend__ = True datatype = Time data = datetime.time(12, 57, 18, 396) class DateTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'date', + __requires__ = ("date",) __backend__ = True datatype = Date data = datetime.date(2012, 10, 15) class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'date', 'date_coerces_from_datetime' + __requires__ = "date", "date_coerces_from_datetime" __backend__ = True datatype = Date data = datetime.datetime(2012, 10, 15, 12, 57, 18) @@ -319,14 +296,14 @@ class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'datetime_historic', + __requires__ = ("datetime_historic",) __backend__ = True datatype = DateTime data = datetime.datetime(1850, 11, 10, 11, 52, 35) class DateHistoricTest(_DateFixture, fixtures.TablesTest): - __requires__ = 'date_historic', + __requires__ = ("date_historic",) __backend__ = True datatype = Date data = datetime.date(1727, 4, 1) @@ -345,26 +322,21 @@ class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): def _round_trip(self, datatype, data): metadata = self.metadata int_table = Table( - 'integer_table', metadata, - Column('id', Integer, primary_key=True, - test_needs_autoincrement=True), - Column('integer_data', datatype), + "integer_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("integer_data", datatype), ) metadata.create_all(config.db) - config.db.execute( - int_table.insert(), - {'integer_data': data} - ) + config.db.execute(int_table.insert(), {"integer_data": data}) - row = config.db.execute( - select([ - int_table.c.integer_data, - ]) - ).first() + row = config.db.execute(select([int_table.c.integer_data])).first() - eq_(row, (data, )) + eq_(row, (data,)) if util.py3k: assert isinstance(row[0], int) @@ -377,12 +349,11 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): @testing.emits_warning(r".*does \*not\* support Decimal objects natively") @testing.provide_metadata - def _do_test(self, type_, input_, output, - filter_=None, check_scale=False): + def _do_test(self, type_, input_, output, filter_=None, check_scale=False): metadata = self.metadata - t = Table('t', metadata, Column('x', type_)) + t = Table("t", metadata, Column("x", type_)) t.create() - t.insert().execute([{'x': x} for x in input_]) + t.insert().execute([{"x": x} for x in input_]) result = {row[0] for row in t.select().execute()} output = set(output) @@ -391,10 +362,7 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): output = set(filter_(x) for x in output) eq_(result, output) if check_scale: - eq_( - [str(x) for x in result], - [str(x) for x in output], - ) + eq_([str(x) for x in result], [str(x) for x in output]) @testing.emits_warning(r".*does \*not\* support Decimal objects natively") def test_render_literal_numeric(self): @@ -416,8 +384,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): self._literal_round_trip( Float(4), [15.7563, decimal.Decimal("15.7563")], - [15.7563, ], - filter_=lambda n: n is not None and round(n, 5) or None + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None, ) @testing.requires.precision_generic_float_type @@ -425,8 +393,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): self._do_test( Float(None, decimal_return_scale=7, asdecimal=True), [15.7563827, decimal.Decimal("15.7563827")], - [decimal.Decimal("15.7563827"), ], - check_scale=True + [decimal.Decimal("15.7563827")], + check_scale=True, ) def test_numeric_as_decimal(self): @@ -445,18 +413,12 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): @testing.requires.fetch_null_from_numeric def test_numeric_null_as_decimal(self): - self._do_test( - Numeric(precision=8, scale=4), - [None], - [None], - ) + self._do_test(Numeric(precision=8, scale=4), [None], [None]) @testing.requires.fetch_null_from_numeric def test_numeric_null_as_float(self): self._do_test( - Numeric(precision=8, scale=4, asdecimal=False), - [None], - [None], + Numeric(precision=8, scale=4, asdecimal=False), [None], [None] ) @testing.requires.floats_to_four_decimals @@ -472,15 +434,13 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): Float(precision=8), [15.7563, decimal.Decimal("15.7563")], [15.7563], - filter_=lambda n: n is not None and round(n, 5) or None + filter_=lambda n: n is not None and round(n, 5) or None, ) def test_float_coerce_round_trip(self): expr = 15.7563 - val = testing.db.scalar( - select([literal(expr)]) - ) + val = testing.db.scalar(select([literal(expr)])) eq_(val, expr) # this does not work in MySQL, see #4036, however we choose not @@ -491,34 +451,28 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): def test_decimal_coerce_round_trip(self): expr = decimal.Decimal("15.7563") - val = testing.db.scalar( - select([literal(expr)]) - ) + val = testing.db.scalar(select([literal(expr)])) eq_(val, expr) @testing.emits_warning(r".*does \*not\* support Decimal objects natively") def test_decimal_coerce_round_trip_w_cast(self): expr = decimal.Decimal("15.7563") - val = testing.db.scalar( - select([cast(expr, Numeric(10, 4))]) - ) + val = testing.db.scalar(select([cast(expr, Numeric(10, 4))])) eq_(val, expr) @testing.requires.precision_numerics_general def test_precision_decimal(self): - numbers = set([ - decimal.Decimal("54.234246451650"), - decimal.Decimal("0.004354"), - decimal.Decimal("900.0"), - ]) - - self._do_test( - Numeric(precision=18, scale=12), - numbers, - numbers, + numbers = set( + [ + decimal.Decimal("54.234246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + ] ) + self._do_test(Numeric(precision=18, scale=12), numbers, numbers) + @testing.requires.precision_numerics_enotation_large def test_enotation_decimal(self): """test exceedingly small decimals. @@ -528,25 +482,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): """ - numbers = set([ - decimal.Decimal('1E-2'), - decimal.Decimal('1E-3'), - decimal.Decimal('1E-4'), - decimal.Decimal('1E-5'), - decimal.Decimal('1E-6'), - decimal.Decimal('1E-7'), - decimal.Decimal('1E-8'), - decimal.Decimal("0.01000005940696"), - decimal.Decimal("0.00000005940696"), - decimal.Decimal("0.00000000000696"), - decimal.Decimal("0.70000000000696"), - decimal.Decimal("696E-12"), - ]) - self._do_test( - Numeric(precision=18, scale=14), - numbers, - numbers + numbers = set( + [ + decimal.Decimal("1E-2"), + decimal.Decimal("1E-3"), + decimal.Decimal("1E-4"), + decimal.Decimal("1E-5"), + decimal.Decimal("1E-6"), + decimal.Decimal("1E-7"), + decimal.Decimal("1E-8"), + decimal.Decimal("0.01000005940696"), + decimal.Decimal("0.00000005940696"), + decimal.Decimal("0.00000000000696"), + decimal.Decimal("0.70000000000696"), + decimal.Decimal("696E-12"), + ] ) + self._do_test(Numeric(precision=18, scale=14), numbers, numbers) @testing.requires.precision_numerics_enotation_large def test_enotation_decimal_large(self): @@ -554,41 +506,32 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): """ - numbers = set([ - decimal.Decimal('4E+8'), - decimal.Decimal("5748E+15"), - decimal.Decimal('1.521E+15'), - decimal.Decimal('00000000000000.1E+12'), - ]) - self._do_test( - Numeric(precision=25, scale=2), - numbers, - numbers + numbers = set( + [ + decimal.Decimal("4E+8"), + decimal.Decimal("5748E+15"), + decimal.Decimal("1.521E+15"), + decimal.Decimal("00000000000000.1E+12"), + ] ) + self._do_test(Numeric(precision=25, scale=2), numbers, numbers) @testing.requires.precision_numerics_many_significant_digits def test_many_significant_digits(self): - numbers = set([ - decimal.Decimal("31943874831932418390.01"), - decimal.Decimal("319438950232418390.273596"), - decimal.Decimal("87673.594069654243"), - ]) - self._do_test( - Numeric(precision=38, scale=12), - numbers, - numbers + numbers = set( + [ + decimal.Decimal("31943874831932418390.01"), + decimal.Decimal("319438950232418390.273596"), + decimal.Decimal("87673.594069654243"), + ] ) + self._do_test(Numeric(precision=38, scale=12), numbers, numbers) @testing.requires.precision_numerics_retains_significant_digits def test_numeric_no_decimal(self): - numbers = set([ - decimal.Decimal("1.000") - ]) + numbers = set([decimal.Decimal("1.000")]) self._do_test( - Numeric(precision=5, scale=3), - numbers, - numbers, - check_scale=True + Numeric(precision=5, scale=3), numbers, numbers, check_scale=True ) @@ -597,42 +540,32 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - Table('boolean_table', metadata, - Column('id', Integer, primary_key=True, autoincrement=False), - Column('value', Boolean), - Column('unconstrained_value', Boolean(create_constraint=False)), - ) + Table( + "boolean_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("value", Boolean), + Column("unconstrained_value", Boolean(create_constraint=False)), + ) def test_render_literal_bool(self): - self._literal_round_trip( - Boolean(), - [True, False], - [True, False] - ) + self._literal_round_trip(Boolean(), [True, False], [True, False]) def test_round_trip(self): boolean_table = self.tables.boolean_table config.db.execute( boolean_table.insert(), - { - 'id': 1, - 'value': True, - 'unconstrained_value': False - } + {"id": 1, "value": True, "unconstrained_value": False}, ) row = config.db.execute( - select([ - boolean_table.c.value, - boolean_table.c.unconstrained_value - ]) + select( + [boolean_table.c.value, boolean_table.c.unconstrained_value] + ) ).first() - eq_( - row, - (True, False) - ) + eq_(row, (True, False)) assert isinstance(row[0], bool) def test_null(self): @@ -640,24 +573,16 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): config.db.execute( boolean_table.insert(), - { - 'id': 1, - 'value': None, - 'unconstrained_value': None - } + {"id": 1, "value": None, "unconstrained_value": None}, ) row = config.db.execute( - select([ - boolean_table.c.value, - boolean_table.c.unconstrained_value - ]) + select( + [boolean_table.c.value, boolean_table.c.unconstrained_value] + ) ).first() - eq_( - row, - (None, None) - ) + eq_(row, (None, None)) def test_whereclause(self): # testing "WHERE <column>" renders a compatible expression @@ -667,92 +592,82 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): conn.execute( boolean_table.insert(), [ - {'id': 1, 'value': True, 'unconstrained_value': True}, - {'id': 2, 'value': False, 'unconstrained_value': False} - ] + {"id": 1, "value": True, "unconstrained_value": True}, + {"id": 2, "value": False, "unconstrained_value": False}, + ], ) eq_( conn.scalar( select([boolean_table.c.id]).where(boolean_table.c.value) ), - 1 + 1, ) eq_( conn.scalar( select([boolean_table.c.id]).where( - boolean_table.c.unconstrained_value) + boolean_table.c.unconstrained_value + ) ), - 1 + 1, ) eq_( conn.scalar( select([boolean_table.c.id]).where(~boolean_table.c.value) ), - 2 + 2, ) eq_( conn.scalar( select([boolean_table.c.id]).where( - ~boolean_table.c.unconstrained_value) + ~boolean_table.c.unconstrained_value + ) ), - 2 + 2, ) - - class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): - __requires__ = 'json_type', + __requires__ = ("json_type",) __backend__ = True datatype = JSON - data1 = { - "key1": "value1", - "key2": "value2" - } + data1 = {"key1": "value1", "key2": "value2"} data2 = { "Key 'One'": "value1", "key two": "value2", - "key three": "value ' three '" + "key three": "value ' three '", } data3 = { "key1": [1, 2, 3], "key2": ["one", "two", "three"], - "key3": [{"four": "five"}, {"six": "seven"}] + "key3": [{"four": "five"}, {"six": "seven"}], } data4 = ["one", "two", "three"] data5 = { "nested": { - "elem1": [ - {"a": "b", "c": "d"}, - {"e": "f", "g": "h"} - ], - "elem2": { - "elem3": {"elem4": "elem5"} - } + "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}], + "elem2": {"elem3": {"elem4": "elem5"}}, } } - data6 = { - "a": 5, - "b": "some value", - "c": {"foo": "bar"} - } + data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}} @classmethod def define_tables(cls, metadata): - Table('data_table', metadata, - Column('id', Integer, primary_key=True), - Column('name', String(30), nullable=False), - Column('data', cls.datatype), - Column('nulldata', cls.datatype(none_as_null=True)) - ) + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.datatype), + Column("nulldata", cls.datatype(none_as_null=True)), + ) def test_round_trip_data1(self): self._test_round_trip(self.data1) @@ -761,99 +676,82 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): data_table = self.tables.data_table config.db.execute( - data_table.insert(), - {'name': 'row1', 'data': data_element} + data_table.insert(), {"name": "row1", "data": data_element} ) - row = config.db.execute( - select([ - data_table.c.data, - ]) - ).first() + row = config.db.execute(select([data_table.c.data])).first() - eq_(row, (data_element, )) + eq_(row, (data_element,)) def test_round_trip_none_as_sql_null(self): - col = self.tables.data_table.c['nulldata'] + col = self.tables.data_table.c["nulldata"] with config.db.connect() as conn: conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": None} + self.tables.data_table.insert(), {"name": "r1", "data": None} ) eq_( conn.scalar( - select([self.tables.data_table.c.name]). - where(col.is_(null())) + select([self.tables.data_table.c.name]).where( + col.is_(null()) + ) ), - "r1" + "r1", ) - eq_( - conn.scalar( - select([col]) - ), - None - ) + eq_(conn.scalar(select([col])), None) def test_round_trip_json_null_as_json_null(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] with config.db.connect() as conn: conn.execute( self.tables.data_table.insert(), - {"name": "r1", "data": JSON.NULL} + {"name": "r1", "data": JSON.NULL}, ) eq_( conn.scalar( - select([self.tables.data_table.c.name]). - where(cast(col, String) == 'null') + select([self.tables.data_table.c.name]).where( + cast(col, String) == "null" + ) ), - "r1" + "r1", ) - eq_( - conn.scalar( - select([col]) - ), - None - ) + eq_(conn.scalar(select([col])), None) def test_round_trip_none_as_json_null(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] with config.db.connect() as conn: conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": None} + self.tables.data_table.insert(), {"name": "r1", "data": None} ) eq_( conn.scalar( - select([self.tables.data_table.c.name]). - where(cast(col, String) == 'null') + select([self.tables.data_table.c.name]).where( + cast(col, String) == "null" + ) ), - "r1" + "r1", ) - eq_( - conn.scalar( - select([col]) - ), - None - ) + eq_(conn.scalar(select([col])), None) def _criteria_fixture(self): config.db.execute( self.tables.data_table.insert(), - [{"name": "r1", "data": self.data1}, - {"name": "r2", "data": self.data2}, - {"name": "r3", "data": self.data3}, - {"name": "r4", "data": self.data4}, - {"name": "r5", "data": self.data5}, - {"name": "r6", "data": self.data6}] + [ + {"name": "r1", "data": self.data1}, + {"name": "r2", "data": self.data2}, + {"name": "r3", "data": self.data3}, + {"name": "r4", "data": self.data4}, + {"name": "r5", "data": self.data5}, + {"name": "r6", "data": self.data6}, + ], ) def _test_index_criteria(self, crit, expected, test_literal=True): @@ -861,20 +759,20 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): with config.db.connect() as conn: stmt = select([self.tables.data_table.c.name]).where(crit) - eq_( - conn.scalar(stmt), - expected - ) + eq_(conn.scalar(stmt), expected) if test_literal: - literal_sql = str(stmt.compile( - config.db, compile_kwargs={"literal_binds": True})) + literal_sql = str( + stmt.compile( + config.db, compile_kwargs={"literal_binds": True} + ) + ) eq_(conn.scalar(literal_sql), expected) def test_crit_spaces_in_key(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] # limit the rows here to avoid PG error # "cannot extract field from a non-object", which is @@ -882,76 +780,74 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): self._test_index_criteria( and_( name.in_(["r1", "r2", "r3"]), - cast(col["key two"], String) == '"value2"' + cast(col["key two"], String) == '"value2"', ), - "r2" + "r2", ) @config.requirements.json_array_indexes def test_crit_simple_int(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] # limit the rows here to avoid PG error # "cannot extract array element from a non-array", which is # fixed in 9.4 but may exist in 9.3 self._test_index_criteria( - and_(name == 'r4', cast(col[1], String) == '"two"'), - "r4" + and_(name == "r4", cast(col[1], String) == '"two"'), "r4" ) def test_crit_mixed_path(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - cast(col[("key3", 1, "six")], String) == '"seven"', - "r3" + cast(col[("key3", 1, "six")], String) == '"seven"', "r3" ) def test_crit_string_path(self): - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( cast(col[("nested", "elem2", "elem3", "elem4")], String) == '"elem5"', - "r5" + "r5", ) def test_crit_against_string_basic(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', cast(col["b"], String) == '"some value"'), - "r6" + and_(name == "r6", cast(col["b"], String) == '"some value"'), "r6" ) def test_crit_against_string_coerce_type(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', - cast(col["b"], String) == type_coerce("some value", JSON)), + and_( + name == "r6", + cast(col["b"], String) == type_coerce("some value", JSON), + ), "r6", - test_literal=False + test_literal=False, ) def test_crit_against_int_basic(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', cast(col["a"], String) == '5'), - "r6" + and_(name == "r6", cast(col["a"], String) == "5"), "r6" ) def test_crit_against_int_coerce_type(self): name = self.tables.data_table.c.name - col = self.tables.data_table.c['data'] + col = self.tables.data_table.c["data"] self._test_index_criteria( - and_(name == 'r6', cast(col["a"], String) == type_coerce(5, JSON)), + and_(name == "r6", cast(col["a"], String) == type_coerce(5, JSON)), "r6", - test_literal=False + test_literal=False, ) def test_unicode_round_trip(self): @@ -961,17 +857,17 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): { "name": "r1", "data": { - util.u('réveillé'): util.u('réveillé'), - "data": {"k1": util.u('drôle')} - } - } + util.u("réveillé"): util.u("réveillé"), + "data": {"k1": util.u("drôle")}, + }, + }, ) eq_( conn.scalar(select([self.tables.data_table.c.data])), { - util.u('réveillé'): util.u('réveillé'), - "data": {"k1": util.u('drôle')} + util.u("réveillé"): util.u("réveillé"), + "data": {"k1": util.u("drôle")}, }, ) @@ -986,7 +882,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): s = Session(testing.db) - d1 = Data(name='d1', data=None, nulldata=None) + d1 = Data(name="d1", data=None, nulldata=None) s.add(d1) s.commit() @@ -995,24 +891,46 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): ) eq_( s.query( - cast(self.tables.data_table.c.data, String(convert_unicode="force")), - cast(self.tables.data_table.c.nulldata, String) - ).filter(self.tables.data_table.c.name == 'd1').first(), - ("null", None) + cast( + self.tables.data_table.c.data, + String(convert_unicode="force"), + ), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), ) eq_( s.query( - cast(self.tables.data_table.c.data, String(convert_unicode="force")), - cast(self.tables.data_table.c.nulldata, String) - ).filter(self.tables.data_table.c.name == 'd2').first(), - ("null", None) - ) - - -__all__ = ('UnicodeVarcharTest', 'UnicodeTextTest', 'JSONTest', - 'DateTest', 'DateTimeTest', 'TextTest', - 'NumericTest', 'IntegerTest', - 'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest', - 'TimeMicrosecondsTest', 'TimestampMicrosecondsTest', 'TimeTest', - 'DateTimeMicrosecondsTest', - 'DateHistoricTest', 'StringTest', 'BooleanTest') + cast( + self.tables.data_table.c.data, + String(convert_unicode="force"), + ), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), + ) + + +__all__ = ( + "UnicodeVarcharTest", + "UnicodeTextTest", + "JSONTest", + "DateTest", + "DateTimeTest", + "TextTest", + "NumericTest", + "IntegerTest", + "DateTimeHistoricTest", + "DateTimeCoercedToDateTimeTest", + "TimeMicrosecondsTest", + "TimestampMicrosecondsTest", + "TimeTest", + "DateTimeMicrosecondsTest", + "DateHistoricTest", + "StringTest", + "BooleanTest", +) diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py index e4c61e74a..b232c3a78 100644 --- a/lib/sqlalchemy/testing/suite/test_update_delete.py +++ b/lib/sqlalchemy/testing/suite/test_update_delete.py @@ -6,15 +6,17 @@ from ..schema import Table, Column class SimpleUpdateDeleteTest(fixtures.TablesTest): - run_deletes = 'each' + run_deletes = "each" __backend__ = True @classmethod def define_tables(cls, metadata): - Table('plain_pk', metadata, - Column('id', Integer, primary_key=True), - Column('data', String(50)) - ) + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) @classmethod def insert_data(cls): @@ -24,40 +26,29 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest): {"id": 1, "data": "d1"}, {"id": 2, "data": "d2"}, {"id": 3, "data": "d3"}, - ] + ], ) def test_update(self): t = self.tables.plain_pk - r = config.db.execute( - t.update().where(t.c.id == 2), - data="d2_new" - ) + r = config.db.execute(t.update().where(t.c.id == 2), data="d2_new") assert not r.is_insert assert not r.returns_rows eq_( config.db.execute(t.select().order_by(t.c.id)).fetchall(), - [ - (1, "d1"), - (2, "d2_new"), - (3, "d3") - ] + [(1, "d1"), (2, "d2_new"), (3, "d3")], ) def test_delete(self): t = self.tables.plain_pk - r = config.db.execute( - t.delete().where(t.c.id == 2) - ) + r = config.db.execute(t.delete().where(t.c.id == 2)) assert not r.is_insert assert not r.returns_rows eq_( config.db.execute(t.select().order_by(t.c.id)).fetchall(), - [ - (1, "d1"), - (3, "d3") - ] + [(1, "d1"), (3, "d3")], ) -__all__ = ('SimpleUpdateDeleteTest', ) + +__all__ = ("SimpleUpdateDeleteTest",) diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 409d3bda5..5b015d214 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -14,6 +14,7 @@ import sys import types if jython: + def jython_gc_collect(*args): """aggressive gc.collect for tests.""" gc.collect() @@ -25,9 +26,11 @@ if jython: # "lazy" gc, for VM's that don't GC on refcount == 0 gc_collect = lazy_gc = jython_gc_collect elif pypy: + def pypy_gc_collect(*args): gc.collect() gc.collect() + gc_collect = lazy_gc = pypy_gc_collect else: # assume CPython - straight gc.collect, lazy_gc() is a pass @@ -42,11 +45,13 @@ def picklers(): if py2k: try: import cPickle + picklers.add(cPickle) except ImportError: pass import pickle + picklers.add(pickle) # yes, this thing needs this much testing @@ -60,9 +65,9 @@ def round_decimal(value, prec): return round(value, prec) # can also use shift() here but that is 2.6 only - return (value * decimal.Decimal("1" + "0" * prec) - ).to_integral(decimal.ROUND_FLOOR) / \ - pow(10, prec) + return (value * decimal.Decimal("1" + "0" * prec)).to_integral( + decimal.ROUND_FLOOR + ) / pow(10, prec) class RandomSet(set): @@ -137,8 +142,9 @@ def function_named(fn, name): try: fn.__name__ = name except TypeError: - fn = types.FunctionType(fn.__code__, fn.__globals__, name, - fn.__defaults__, fn.__closure__) + fn = types.FunctionType( + fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__ + ) return fn @@ -190,7 +196,7 @@ def provide_metadata(fn, *args, **kw): metadata = schema.MetaData(config.db) self = args[0] - prev_meta = getattr(self, 'metadata', None) + prev_meta = getattr(self, "metadata", None) self.metadata = metadata try: return fn(*args, **kw) @@ -213,8 +219,8 @@ def force_drop_names(*names): try: return fn(*args, **kw) finally: - drop_all_tables( - config.db, inspect(config.db), include_names=names) + drop_all_tables(config.db, inspect(config.db), include_names=names) + return go @@ -234,8 +240,13 @@ class adict(dict): def drop_all_tables(engine, inspector, schema=None, include_names=None): - from sqlalchemy import Column, Table, Integer, MetaData, \ - ForeignKeyConstraint + from sqlalchemy import ( + Column, + Table, + Integer, + MetaData, + ForeignKeyConstraint, + ) from sqlalchemy.schema import DropTable, DropConstraint if include_names is not None: @@ -243,30 +254,35 @@ def drop_all_tables(engine, inspector, schema=None, include_names=None): with engine.connect() as conn: for tname, fkcs in reversed( - inspector.get_sorted_table_and_fkc_names(schema=schema)): + inspector.get_sorted_table_and_fkc_names(schema=schema) + ): if tname: if include_names is not None and tname not in include_names: continue - conn.execute(DropTable( - Table(tname, MetaData(), schema=schema) - )) + conn.execute( + DropTable(Table(tname, MetaData(), schema=schema)) + ) elif fkcs: if not engine.dialect.supports_alter: continue for tname, fkc in fkcs: - if include_names is not None and \ - tname not in include_names: + if ( + include_names is not None + and tname not in include_names + ): continue tb = Table( - tname, MetaData(), - Column('x', Integer), - Column('y', Integer), - schema=schema + tname, + MetaData(), + Column("x", Integer), + Column("y", Integer), + schema=schema, + ) + conn.execute( + DropConstraint( + ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc) + ) ) - conn.execute(DropConstraint( - ForeignKeyConstraint( - [tb.c.x], [tb.c.y], name=fkc) - )) def teardown_events(event_cls): @@ -276,5 +292,5 @@ def teardown_events(event_cls): return fn(*arg, **kw) finally: event_cls._clear() - return decorate + return decorate diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index 46e7c54db..e0101b14d 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -15,17 +15,20 @@ from . import assertions def setup_filters(): """Set global warning behavior for the test suite.""" - warnings.filterwarnings('ignore', - category=sa_exc.SAPendingDeprecationWarning) - warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) - warnings.filterwarnings('error', category=sa_exc.SAWarning) + warnings.filterwarnings( + "ignore", category=sa_exc.SAPendingDeprecationWarning + ) + warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning) + warnings.filterwarnings("error", category=sa_exc.SAWarning) # some selected deprecations... - warnings.filterwarnings('error', category=DeprecationWarning) + warnings.filterwarnings("error", category=DeprecationWarning) warnings.filterwarnings( - "ignore", category=DeprecationWarning, message=".*StopIteration") + "ignore", category=DeprecationWarning, message=".*StopIteration" + ) warnings.filterwarnings( - "ignore", category=DeprecationWarning, message=".*inspect.getargspec") + "ignore", category=DeprecationWarning, message=".*inspect.getargspec" + ) def assert_warnings(fn, warning_msgs, regex=False): @@ -36,6 +39,6 @@ def assert_warnings(fn, warning_msgs, regex=False): """ with assertions._expect_warnings( - sa_exc.SAWarning, warning_msgs, regex=regex): + sa_exc.SAWarning, warning_msgs, regex=regex + ): return fn() - |
