diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
---|---|---|
committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/testing/assertsql.py | |
parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 148 |
1 files changed, 80 insertions, 68 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 7a525589d..d8e924cb6 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -26,8 +26,10 @@ class AssertRule(object): pass def no_more_statements(self): - assert False, 'All statements are complete, but pending '\ - 'assertion rules remain' + assert False, ( + "All statements are complete, but pending " + "assertion rules remain" + ) class SQLMatchRule(AssertRule): @@ -44,12 +46,17 @@ class CursorSQL(SQLMatchRule): def process_statement(self, execute_observed): stmt = execute_observed.statements[0] if self.statement != stmt.statement or ( - self.params is not None and self.params != stmt.parameters): - self.errormessage = \ - "Testing for exact SQL %s parameters %s received %s %s" % ( - self.statement, self.params, - stmt.statement, stmt.parameters + self.params is not None and self.params != stmt.parameters + ): + self.errormessage = ( + "Testing for exact SQL %s parameters %s received %s %s" + % ( + self.statement, + self.params, + stmt.statement, + stmt.parameters, ) + ) else: execute_observed.statements.pop(0) self.is_consumed = True @@ -58,23 +65,22 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): - - def __init__(self, statement, params=None, dialect='default'): + def __init__(self, statement, params=None, dialect="default"): self.statement = statement self.params = params self.dialect = dialect def _compare_sql(self, execute_observed, received_statement): - stmt = re.sub(r'[\n\t]', '', self.statement) + stmt = re.sub(r"[\n\t]", "", self.statement) return received_statement == stmt def _compile_dialect(self, execute_observed): - if self.dialect == 'default': + if self.dialect == "default": return DefaultDialect() else: # ugh - if self.dialect == 'postgresql': - params = {'implicit_returning': True} + if self.dialect == "postgresql": + params = {"implicit_returning": True} else: params = {} return url.URL(self.dialect).get_dialect()(**params) @@ -86,36 +92,39 @@ class CompiledSQL(SQLMatchRule): context = execute_observed.context compare_dialect = self._compile_dialect(execute_observed) if isinstance(context.compiled.statement, _DDLCompiles): - compiled = \ - context.compiled.statement.compile( - dialect=compare_dialect, - schema_translate_map=context. - execution_options.get('schema_translate_map')) + compiled = context.compiled.statement.compile( + dialect=compare_dialect, + schema_translate_map=context.execution_options.get( + "schema_translate_map" + ), + ) else: - compiled = ( - context.compiled.statement.compile( - dialect=compare_dialect, - column_keys=context.compiled.column_keys, - inline=context.compiled.inline, - schema_translate_map=context. - execution_options.get('schema_translate_map')) + compiled = context.compiled.statement.compile( + dialect=compare_dialect, + column_keys=context.compiled.column_keys, + inline=context.compiled.inline, + schema_translate_map=context.execution_options.get( + "schema_translate_map" + ), ) - _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled)) + _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled)) parameters = execute_observed.parameters if not parameters: _received_parameters = [compiled.construct_params()] else: _received_parameters = [ - compiled.construct_params(m) for m in parameters] + compiled.construct_params(m) for m in parameters + ] return _received_statement, _received_parameters def process_statement(self, execute_observed): context = execute_observed.context - _received_statement, _received_parameters = \ - self._received_statement(execute_observed) + _received_statement, _received_parameters = self._received_statement( + execute_observed + ) params = self._all_params(context) equivalent = self._compare_sql(execute_observed, _received_statement) @@ -132,8 +141,10 @@ class CompiledSQL(SQLMatchRule): for param_key in param: # a key in param did not match current # 'received' - if param_key not in received or \ - received[param_key] != param[param_key]: + if ( + param_key not in received + or received[param_key] != param[param_key] + ): break else: # all keys in param matched 'received'; @@ -153,8 +164,8 @@ class CompiledSQL(SQLMatchRule): self.errormessage = None else: self.errormessage = self._failure_message(params) % { - 'received_statement': _received_statement, - 'received_parameters': _received_parameters + "received_statement": _received_statement, + "received_parameters": _received_parameters, } def _all_params(self, context): @@ -171,11 +182,10 @@ class CompiledSQL(SQLMatchRule): def _failure_message(self, expected_params): return ( - 'Testing for compiled statement %r partial params %r, ' - 'received %%(received_statement)r with params ' - '%%(received_parameters)r' % ( - self.statement.replace('%', '%%'), expected_params - ) + "Testing for compiled statement %r partial params %r, " + "received %%(received_statement)r with params " + "%%(received_parameters)r" + % (self.statement.replace("%", "%%"), expected_params) ) @@ -185,15 +195,13 @@ class RegexSQL(CompiledSQL): self.regex = re.compile(regex) self.orig_regex = regex self.params = params - self.dialect = 'default' + self.dialect = "default" def _failure_message(self, expected_params): return ( - 'Testing for compiled statement ~%r partial params %r, ' - 'received %%(received_statement)r with params ' - '%%(received_parameters)r' % ( - self.orig_regex, expected_params - ) + "Testing for compiled statement ~%r partial params %r, " + "received %%(received_statement)r with params " + "%%(received_parameters)r" % (self.orig_regex, expected_params) ) def _compare_sql(self, execute_observed, received_statement): @@ -205,12 +213,13 @@ class DialectSQL(CompiledSQL): return execute_observed.context.dialect def _compare_no_space(self, real_stmt, received_stmt): - stmt = re.sub(r'[\n\t]', '', real_stmt) + stmt = re.sub(r"[\n\t]", "", real_stmt) return received_stmt == stmt def _received_statement(self, execute_observed): - received_stmt, received_params = super(DialectSQL, self).\ - _received_statement(execute_observed) + received_stmt, received_params = super( + DialectSQL, self + )._received_statement(execute_observed) # TODO: why do we need this part? for real_stmt in execute_observed.statements: @@ -219,34 +228,33 @@ class DialectSQL(CompiledSQL): else: raise AssertionError( "Can't locate compiled statement %r in list of " - "statements actually invoked" % received_stmt) + "statements actually invoked" % received_stmt + ) return received_stmt, execute_observed.context.compiled_parameters def _compare_sql(self, execute_observed, received_statement): - stmt = re.sub(r'[\n\t]', '', self.statement) + stmt = re.sub(r"[\n\t]", "", self.statement) # convert our comparison statement to have the # paramstyle of the received paramstyle = execute_observed.context.dialect.paramstyle - if paramstyle == 'pyformat': - stmt = re.sub( - r':([\w_]+)', r"%(\1)s", stmt) + if paramstyle == "pyformat": + stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt) else: # positional params repl = None - if paramstyle == 'qmark': + if paramstyle == "qmark": repl = "?" - elif paramstyle == 'format': + elif paramstyle == "format": repl = r"%s" - elif paramstyle == 'numeric': + elif paramstyle == "numeric": repl = None - stmt = re.sub(r':([\w_]+)', repl, stmt) + stmt = re.sub(r":([\w_]+)", repl, stmt) return received_statement == stmt class CountStatements(AssertRule): - def __init__(self, count): self.count = count self._statement_count = 0 @@ -256,12 +264,13 @@ class CountStatements(AssertRule): def no_more_statements(self): if self.count != self._statement_count: - assert False, 'desired statement count %d does not match %d' \ - % (self.count, self._statement_count) + assert False, "desired statement count %d does not match %d" % ( + self.count, + self._statement_count, + ) class AllOf(AssertRule): - def __init__(self, *rules): self.rules = set(rules) @@ -283,7 +292,6 @@ class AllOf(AssertRule): class EachOf(AssertRule): - def __init__(self, *rules): self.rules = list(rules) @@ -309,7 +317,6 @@ class EachOf(AssertRule): class Or(AllOf): - def process_statement(self, execute_observed): for rule in self.rules: rule.process_statement(execute_observed) @@ -331,7 +338,8 @@ class SQLExecuteObserved(object): class SQLCursorExecuteObserved( collections.namedtuple( "SQLCursorExecuteObserved", - ["statement", "parameters", "context", "executemany"]) + ["statement", "parameters", "context", "executemany"], + ) ): pass @@ -374,21 +382,25 @@ def assert_engine(engine): orig[:] = clauseelement, multiparams, params @event.listens_for(engine, "after_cursor_execute") - def cursor_execute(conn, cursor, statement, parameters, - context, executemany): + def cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): if not context: return # then grab real cursor statements and associate them all # around a single context - if asserter.accumulated and \ - asserter.accumulated[-1].context is context: + if ( + asserter.accumulated + and asserter.accumulated[-1].context is context + ): obs = asserter.accumulated[-1] else: obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2]) asserter.accumulated.append(obs) obs.statements.append( SQLCursorExecuteObserved( - statement, parameters, context, executemany) + statement, parameters, context, executemany + ) ) try: |