summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertsql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-01-06 01:14:26 -0500
committermike bayer <mike_mp@zzzcomputing.com>2019-01-06 17:34:50 +0000
commit1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch)
tree28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/testing/assertsql.py
parent404e69426b05a82d905cbb3ad33adafccddb00dd (diff)
downloadsqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits applied at all. The black run will format code consistently, however in some cases that are prevalent in SQLAlchemy code it produces too-long lines. The too-long lines will be resolved in the following commit that will resolve all remaining flake8 issues including shadowed builtins, long lines, import order, unused imports, duplicate imports, and docstring issues. Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
-rw-r--r--lib/sqlalchemy/testing/assertsql.py148
1 files changed, 80 insertions, 68 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 7a525589d..d8e924cb6 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -26,8 +26,10 @@ class AssertRule(object):
pass
def no_more_statements(self):
- assert False, 'All statements are complete, but pending '\
- 'assertion rules remain'
+ assert False, (
+ "All statements are complete, but pending "
+ "assertion rules remain"
+ )
class SQLMatchRule(AssertRule):
@@ -44,12 +46,17 @@ class CursorSQL(SQLMatchRule):
def process_statement(self, execute_observed):
stmt = execute_observed.statements[0]
if self.statement != stmt.statement or (
- self.params is not None and self.params != stmt.parameters):
- self.errormessage = \
- "Testing for exact SQL %s parameters %s received %s %s" % (
- self.statement, self.params,
- stmt.statement, stmt.parameters
+ self.params is not None and self.params != stmt.parameters
+ ):
+ self.errormessage = (
+ "Testing for exact SQL %s parameters %s received %s %s"
+ % (
+ self.statement,
+ self.params,
+ stmt.statement,
+ stmt.parameters,
)
+ )
else:
execute_observed.statements.pop(0)
self.is_consumed = True
@@ -58,23 +65,22 @@ class CursorSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
-
- def __init__(self, statement, params=None, dialect='default'):
+ def __init__(self, statement, params=None, dialect="default"):
self.statement = statement
self.params = params
self.dialect = dialect
def _compare_sql(self, execute_observed, received_statement):
- stmt = re.sub(r'[\n\t]', '', self.statement)
+ stmt = re.sub(r"[\n\t]", "", self.statement)
return received_statement == stmt
def _compile_dialect(self, execute_observed):
- if self.dialect == 'default':
+ if self.dialect == "default":
return DefaultDialect()
else:
# ugh
- if self.dialect == 'postgresql':
- params = {'implicit_returning': True}
+ if self.dialect == "postgresql":
+ params = {"implicit_returning": True}
else:
params = {}
return url.URL(self.dialect).get_dialect()(**params)
@@ -86,36 +92,39 @@ class CompiledSQL(SQLMatchRule):
context = execute_observed.context
compare_dialect = self._compile_dialect(execute_observed)
if isinstance(context.compiled.statement, _DDLCompiles):
- compiled = \
- context.compiled.statement.compile(
- dialect=compare_dialect,
- schema_translate_map=context.
- execution_options.get('schema_translate_map'))
+ compiled = context.compiled.statement.compile(
+ dialect=compare_dialect,
+ schema_translate_map=context.execution_options.get(
+ "schema_translate_map"
+ ),
+ )
else:
- compiled = (
- context.compiled.statement.compile(
- dialect=compare_dialect,
- column_keys=context.compiled.column_keys,
- inline=context.compiled.inline,
- schema_translate_map=context.
- execution_options.get('schema_translate_map'))
+ compiled = context.compiled.statement.compile(
+ dialect=compare_dialect,
+ column_keys=context.compiled.column_keys,
+ inline=context.compiled.inline,
+ schema_translate_map=context.execution_options.get(
+ "schema_translate_map"
+ ),
)
- _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))
+ _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
parameters = execute_observed.parameters
if not parameters:
_received_parameters = [compiled.construct_params()]
else:
_received_parameters = [
- compiled.construct_params(m) for m in parameters]
+ compiled.construct_params(m) for m in parameters
+ ]
return _received_statement, _received_parameters
def process_statement(self, execute_observed):
context = execute_observed.context
- _received_statement, _received_parameters = \
- self._received_statement(execute_observed)
+ _received_statement, _received_parameters = self._received_statement(
+ execute_observed
+ )
params = self._all_params(context)
equivalent = self._compare_sql(execute_observed, _received_statement)
@@ -132,8 +141,10 @@ class CompiledSQL(SQLMatchRule):
for param_key in param:
# a key in param did not match current
# 'received'
- if param_key not in received or \
- received[param_key] != param[param_key]:
+ if (
+ param_key not in received
+ or received[param_key] != param[param_key]
+ ):
break
else:
# all keys in param matched 'received';
@@ -153,8 +164,8 @@ class CompiledSQL(SQLMatchRule):
self.errormessage = None
else:
self.errormessage = self._failure_message(params) % {
- 'received_statement': _received_statement,
- 'received_parameters': _received_parameters
+ "received_statement": _received_statement,
+ "received_parameters": _received_parameters,
}
def _all_params(self, context):
@@ -171,11 +182,10 @@ class CompiledSQL(SQLMatchRule):
def _failure_message(self, expected_params):
return (
- 'Testing for compiled statement %r partial params %r, '
- 'received %%(received_statement)r with params '
- '%%(received_parameters)r' % (
- self.statement.replace('%', '%%'), expected_params
- )
+ "Testing for compiled statement %r partial params %r, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r"
+ % (self.statement.replace("%", "%%"), expected_params)
)
@@ -185,15 +195,13 @@ class RegexSQL(CompiledSQL):
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
- self.dialect = 'default'
+ self.dialect = "default"
def _failure_message(self, expected_params):
return (
- 'Testing for compiled statement ~%r partial params %r, '
- 'received %%(received_statement)r with params '
- '%%(received_parameters)r' % (
- self.orig_regex, expected_params
- )
+ "Testing for compiled statement ~%r partial params %r, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r" % (self.orig_regex, expected_params)
)
def _compare_sql(self, execute_observed, received_statement):
@@ -205,12 +213,13 @@ class DialectSQL(CompiledSQL):
return execute_observed.context.dialect
def _compare_no_space(self, real_stmt, received_stmt):
- stmt = re.sub(r'[\n\t]', '', real_stmt)
+ stmt = re.sub(r"[\n\t]", "", real_stmt)
return received_stmt == stmt
def _received_statement(self, execute_observed):
- received_stmt, received_params = super(DialectSQL, self).\
- _received_statement(execute_observed)
+ received_stmt, received_params = super(
+ DialectSQL, self
+ )._received_statement(execute_observed)
# TODO: why do we need this part?
for real_stmt in execute_observed.statements:
@@ -219,34 +228,33 @@ class DialectSQL(CompiledSQL):
else:
raise AssertionError(
"Can't locate compiled statement %r in list of "
- "statements actually invoked" % received_stmt)
+ "statements actually invoked" % received_stmt
+ )
return received_stmt, execute_observed.context.compiled_parameters
def _compare_sql(self, execute_observed, received_statement):
- stmt = re.sub(r'[\n\t]', '', self.statement)
+ stmt = re.sub(r"[\n\t]", "", self.statement)
# convert our comparison statement to have the
# paramstyle of the received
paramstyle = execute_observed.context.dialect.paramstyle
- if paramstyle == 'pyformat':
- stmt = re.sub(
- r':([\w_]+)', r"%(\1)s", stmt)
+ if paramstyle == "pyformat":
+ stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
else:
# positional params
repl = None
- if paramstyle == 'qmark':
+ if paramstyle == "qmark":
repl = "?"
- elif paramstyle == 'format':
+ elif paramstyle == "format":
repl = r"%s"
- elif paramstyle == 'numeric':
+ elif paramstyle == "numeric":
repl = None
- stmt = re.sub(r':([\w_]+)', repl, stmt)
+ stmt = re.sub(r":([\w_]+)", repl, stmt)
return received_statement == stmt
class CountStatements(AssertRule):
-
def __init__(self, count):
self.count = count
self._statement_count = 0
@@ -256,12 +264,13 @@ class CountStatements(AssertRule):
def no_more_statements(self):
if self.count != self._statement_count:
- assert False, 'desired statement count %d does not match %d' \
- % (self.count, self._statement_count)
+ assert False, "desired statement count %d does not match %d" % (
+ self.count,
+ self._statement_count,
+ )
class AllOf(AssertRule):
-
def __init__(self, *rules):
self.rules = set(rules)
@@ -283,7 +292,6 @@ class AllOf(AssertRule):
class EachOf(AssertRule):
-
def __init__(self, *rules):
self.rules = list(rules)
@@ -309,7 +317,6 @@ class EachOf(AssertRule):
class Or(AllOf):
-
def process_statement(self, execute_observed):
for rule in self.rules:
rule.process_statement(execute_observed)
@@ -331,7 +338,8 @@ class SQLExecuteObserved(object):
class SQLCursorExecuteObserved(
collections.namedtuple(
"SQLCursorExecuteObserved",
- ["statement", "parameters", "context", "executemany"])
+ ["statement", "parameters", "context", "executemany"],
+ )
):
pass
@@ -374,21 +382,25 @@ def assert_engine(engine):
orig[:] = clauseelement, multiparams, params
@event.listens_for(engine, "after_cursor_execute")
- def cursor_execute(conn, cursor, statement, parameters,
- context, executemany):
+ def cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
if not context:
return
# then grab real cursor statements and associate them all
# around a single context
- if asserter.accumulated and \
- asserter.accumulated[-1].context is context:
+ if (
+ asserter.accumulated
+ and asserter.accumulated[-1].context is context
+ ):
obs = asserter.accumulated[-1]
else:
obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
asserter.accumulated.append(obs)
obs.statements.append(
SQLCursorExecuteObserved(
- statement, parameters, context, executemany)
+ statement, parameters, context, executemany
+ )
)
try: