diff options
author | Alessio Bogon <youtux@gmail.com> | 2019-09-15 11:12:24 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-01-22 11:31:23 -0500 |
commit | 3809a5ecfe785cecbc9d91a8e4e4558e3839c694 (patch) | |
tree | dbde348bd6673de8acc685f82884f78f6b9f8f67 /lib/sqlalchemy/sql/compiler.py | |
parent | d3ad35838ede0713d073bbbf78d1bca511806059 (diff) | |
download | sqlalchemy-3809a5ecfe785cecbc9d91a8e4e4558e3839c694.tar.gz |
Query linter option
Added "from linting" as a built-in feature to the SQL compiler. This
allows the compiler to maintain graph of all the FROM clauses in a
particular SELECT statement, linked by criteria in either the WHERE
or in JOIN clauses that link these FROM clauses together. If any two
FROM clauses have no path between them, a warning is emitted that the
query may be producing a cartesian product. As the Core expression
language as well as the ORM are built on an "implicit FROMs" model where
a particular FROM clause is automatically added if any part of the query
refers to it, it is easy for this to happen inadvertently and it is
hoped that the new feature helps with this issue.
The original recipe is from:
https://github.com/sqlalchemy/sqlalchemy/wiki/FromLinter
The linter is now enabled for all tests in the test suite as well.
This has necessitated that a lot of the queries be adjusted to
not include cartesian products. Part of the rationale for the
linter to not be enabled for statement compilation only was to reduce
the need for adjustment for the many test case statements throughout
the test suite that are not real-world statements.
This gerrit is adapted from Ib5946e57c9dba6da428c4d1dee6760b3e978dda0.
Fixes: #4737
Change-Id: Ic91fd9774379f895d021c3ad564db6062299211c
Closes: #4830
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/4830
Pull-request-sha: f8a21aa6262d1bcc9ff0d11a2616e41fba97a47a
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 176 |
1 files changed, 163 insertions, 13 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8499484f3..ed463ebe3 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -41,7 +41,6 @@ from .base import NO_ARG from .. import exc from .. import util - RESERVED_WORDS = set( [ "all", @@ -270,6 +269,89 @@ ExpandedState = collections.namedtuple( ) +NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0) + +COLLECT_CARTESIAN_PRODUCTS = util.symbol( + "COLLECT_CARTESIAN_PRODUCTS", + "Collect data on FROMs and cartesian products and gather " + "into 'self.from_linter'", + canonical=1, +) + +WARN_LINTING = util.symbol( + "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2 +) + +FROM_LINTING = util.symbol( + "FROM_LINTING", + "Warn for cartesian products; " + "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING", + canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING, +) + + +class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): + def lint(self, start=None): + froms = self.froms + if not froms: + return None, None + + edges = set(self.edges) + the_rest = set(froms) + + if start is not None: + start_with = start + the_rest.remove(start_with) + else: + start_with = the_rest.pop() + + stack = collections.deque([start_with]) + + while stack and the_rest: + node = stack.popleft() + the_rest.discard(node) + + # comparison of nodes in edges here is based on hash equality, as + # there are "annotated" elements that match the non-annotated ones. + # to remove the need for in-python hash() calls, use native + # containment routines (e.g. "node in edge", "edge.index(node)") + to_remove = {edge for edge in edges if node in edge} + + # appendleft the node in each edge that is not + # the one that matched. + stack.extendleft(edge[not edge.index(node)] for edge in to_remove) + edges.difference_update(to_remove) + + # FROMS left over? boom + if the_rest: + return the_rest, start_with + else: + return None, None + + def warn(self): + the_rest, start_with = self.lint() + + # FROMS left over? boom + if the_rest: + + froms = the_rest + if froms: + template = ( + "SELECT statement has a cartesian product between " + "FROM element(s) {froms} and " + 'FROM element "{start}". Apply join condition(s) ' + "between each element to resolve." + ) + froms_str = ", ".join( + '"{elem}"'.format(elem=self.froms[from_]) + for from_ in froms + ) + message = template.format( + froms=froms_str, start=self.froms[start_with] + ) + util.warn(message) + + class Compiled(object): """Represent a compiled SQL or DDL expression. @@ -568,7 +650,13 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () def __init__( - self, dialect, statement, column_keys=None, inline=False, **kwargs + self, + dialect, + statement, + column_keys=None, + inline=False, + linting=NO_LINTING, + **kwargs ): """Construct a new :class:`.SQLCompiler` object. @@ -592,6 +680,8 @@ class SQLCompiler(Compiled): # execute) self.inline = inline or getattr(statement, "inline", False) + self.linting = linting + # a dictionary of bind parameter keys to BindParameter # instances. self.binds = {} @@ -1547,9 +1637,21 @@ class SQLCompiler(Compiled): return to_update, replacement_expression def visit_binary( - self, binary, override_operator=None, eager_grouping=False, **kw + self, + binary, + override_operator=None, + eager_grouping=False, + from_linter=None, + **kw ): + if from_linter and operators.is_comparison(binary.operator): + from_linter.edges.update( + itertools.product( + binary.left._from_objects, binary.right._from_objects + ) + ) + # don't allow "? = ?" to render if ( self.ansi_bind_rules @@ -1568,7 +1670,9 @@ class SQLCompiler(Compiled): except KeyError: raise exc.UnsupportedCompilationError(self, operator_) else: - return self._generate_generic_binary(binary, opstring, **kw) + return self._generate_generic_binary( + binary, opstring, from_linter=from_linter, **kw + ) def visit_function_as_comparison_op_binary(self, element, operator, **kw): return self.process(element.sql_function, **kw) @@ -1916,6 +2020,7 @@ class SQLCompiler(Compiled): ashint=False, fromhints=None, visiting_cte=None, + from_linter=None, **kwargs ): self._init_cte_state() @@ -2021,6 +2126,9 @@ class SQLCompiler(Compiled): self.ctes[cte] = text if asfrom: + if from_linter: + from_linter.froms[cte] = cte_name + if not is_new_cte and embedded_in_current_named_cte: return self.preparer.format_alias(cte, cte_name) @@ -2043,6 +2151,7 @@ class SQLCompiler(Compiled): subquery=False, lateral=False, enclosing_alias=None, + from_linter=None, **kwargs ): if enclosing_alias is not None and enclosing_alias.element is alias: @@ -2071,6 +2180,9 @@ class SQLCompiler(Compiled): if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: + if from_linter: + from_linter.froms[alias] = alias_name + inner = alias.element._compiler_dispatch( self, asfrom=True, lateral=lateral, **kwargs ) @@ -2284,6 +2396,7 @@ class SQLCompiler(Compiled): compound_index=0, select_wraps_for=None, lateral=False, + from_linter=None, **kwargs ): @@ -2373,7 +2486,7 @@ class SQLCompiler(Compiled): ] text = self._compose_select_body( - text, select, inner_columns, froms, byfrom, kwargs + text, select, inner_columns, froms, byfrom, toplevel, kwargs ) if select._statement_hints: @@ -2465,10 +2578,17 @@ class SQLCompiler(Compiled): return froms def _compose_select_body( - self, text, select, inner_columns, froms, byfrom, kwargs + self, text, select, inner_columns, froms, byfrom, toplevel, kwargs ): text += ", ".join(inner_columns) + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + if froms: text += " \nFROM " @@ -2476,7 +2596,11 @@ class SQLCompiler(Compiled): text += ", ".join( [ f._compiler_dispatch( - self, asfrom=True, fromhints=byfrom, **kwargs + self, + asfrom=True, + fromhints=byfrom, + from_linter=from_linter, + **kwargs ) for f in froms ] @@ -2484,7 +2608,12 @@ class SQLCompiler(Compiled): else: text += ", ".join( [ - f._compiler_dispatch(self, asfrom=True, **kwargs) + f._compiler_dispatch( + self, + asfrom=True, + from_linter=from_linter, + **kwargs + ) for f in froms ] ) @@ -2492,10 +2621,18 @@ class SQLCompiler(Compiled): text += self.default_from() if select._whereclause is not None: - t = select._whereclause._compiler_dispatch(self, **kwargs) + t = select._whereclause._compiler_dispatch( + self, from_linter=from_linter, **kwargs + ) if t: text += " \nWHERE " + t + if ( + self.linting & COLLECT_CARTESIAN_PRODUCTS + and self.linting & WARN_LINTING + ): + from_linter.warn() + if select._group_by_clause.clauses: text += self.group_by_clause(select, **kwargs) @@ -2597,8 +2734,12 @@ class SQLCompiler(Compiled): ashint=False, fromhints=None, use_schema=True, + from_linter=None, **kwargs ): + if from_linter: + from_linter.froms[table] = table.fullname + if asfrom or ashint: effective_schema = self.preparer.schema_for_object(table) @@ -2618,7 +2759,10 @@ class SQLCompiler(Compiled): else: return "" - def visit_join(self, join, asfrom=False, **kwargs): + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.add((join.left, join.right)) + if join.full: join_type = " FULL OUTER JOIN " elif join.isouter: @@ -2626,12 +2770,18 @@ class SQLCompiler(Compiled): else: join_type = " JOIN " return ( - join.left._compiler_dispatch(self, asfrom=True, **kwargs) + join.left._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + join_type - + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + join.right._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + " ON " # TODO: likely need asfrom=True here? - + join.onclause._compiler_dispatch(self, **kwargs) + + join.onclause._compiler_dispatch( + self, from_linter=from_linter, **kwargs + ) ) def _setup_crud_hints(self, stmt, table_text): |