summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorAlessio Bogon <youtux@gmail.com>2019-09-15 11:12:24 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-01-22 11:31:23 -0500
commit3809a5ecfe785cecbc9d91a8e4e4558e3839c694 (patch)
treedbde348bd6673de8acc685f82884f78f6b9f8f67 /lib/sqlalchemy
parentd3ad35838ede0713d073bbbf78d1bca511806059 (diff)
downloadsqlalchemy-3809a5ecfe785cecbc9d91a8e4e4558e3839c694.tar.gz
Query linter option
Added "from linting" as a built-in feature to the SQL compiler. This allows the compiler to maintain a graph of all the FROM clauses in a particular SELECT statement, linked by the criteria in either the WHERE clause or the JOIN clauses that connect these FROM clauses together. If any two FROM clauses have no path between them, a warning is emitted that the query may be producing a cartesian product. Because the Core expression language and the ORM are both built on an "implicit FROMs" model, where a particular FROM clause is automatically added if any part of the query refers to it, it is easy for this to happen inadvertently, and it is hoped that the new feature helps with this issue. The original recipe is from: https://github.com/sqlalchemy/sqlalchemy/wiki/FromLinter The linter is now enabled for all tests in the test suite as well. This has necessitated that a lot of the queries be adjusted to not include cartesian products. Part of the rationale for not enabling the linter when a statement is only compiled, as opposed to executed, was to reduce the need to adjust the many test case statements throughout the test suite that are not real-world statements. This gerrit is adapted from Ib5946e57c9dba6da428c4d1dee6760b3e978dda0. Fixes: #4737 Change-Id: Ic91fd9774379f895d021c3ad564db6062299211c Closes: #4830 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/4830 Pull-request-sha: f8a21aa6262d1bcc9ff0d11a2616e41fba97a47a
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py15
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py13
-rw-r--r--lib/sqlalchemy/engine/base.py4
-rw-r--r--lib/sqlalchemy/engine/create.py16
-rw-r--r--lib/sqlalchemy/engine/default.py6
-rw-r--r--lib/sqlalchemy/sql/__init__.py4
-rw-r--r--lib/sqlalchemy/sql/compiler.py176
7 files changed, 211 insertions, 23 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 8241d951b..6e84c9da1 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1434,7 +1434,10 @@ class MySQLCompiler(compiler.SQLCompiler):
else:
return ""
- def visit_join(self, join, asfrom=False, **kwargs):
+ def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
+ if from_linter:
+ from_linter.edges.add((join.left, join.right))
+
if join.full:
join_type = " FULL OUTER JOIN "
elif join.isouter:
@@ -1444,11 +1447,15 @@ class MySQLCompiler(compiler.SQLCompiler):
return "".join(
(
- self.process(join.left, asfrom=True, **kwargs),
+ self.process(
+ join.left, asfrom=True, from_linter=from_linter, **kwargs
+ ),
join_type,
- self.process(join.right, asfrom=True, **kwargs),
+ self.process(
+ join.right, asfrom=True, from_linter=from_linter, **kwargs
+ ),
" ON ",
- self.process(join.onclause, **kwargs),
+ self.process(join.onclause, from_linter=from_linter, **kwargs),
)
)
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 9cb25b934..87e0baa58 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -829,19 +829,24 @@ class OracleCompiler(compiler.SQLCompiler):
return " FROM DUAL"
- def visit_join(self, join, **kwargs):
+ def visit_join(self, join, from_linter=None, **kwargs):
if self.dialect.use_ansi:
- return compiler.SQLCompiler.visit_join(self, join, **kwargs)
+ return compiler.SQLCompiler.visit_join(
+ self, join, from_linter=from_linter, **kwargs
+ )
else:
+ if from_linter:
+ from_linter.edges.add((join.left, join.right))
+
kwargs["asfrom"] = True
if isinstance(join.right, expression.FromGrouping):
right = join.right.element
else:
right = join.right
return (
- self.process(join.left, **kwargs)
+ self.process(join.left, from_linter=from_linter, **kwargs)
+ ", "
- + self.process(right, **kwargs)
+ + self.process(right, from_linter=from_linter, **kwargs)
)
def _get_nonansi_join_whereclause(self, froms):
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 88558df5d..462e5f9ec 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -16,6 +16,7 @@ from .. import exc
from .. import inspection
from .. import log
from .. import util
+from ..sql import compiler
from ..sql import schema
from ..sql import util as sql_util
@@ -1083,6 +1084,8 @@ class Connection(Connectable):
schema_translate_map=self.schema_for_object
if not self.schema_for_object.is_default
else None,
+ linting=self.dialect.compiler_linting
+ | compiler.WARN_LINTING,
)
self._execution_options["compiled_cache"][key] = compiled_sql
else:
@@ -1093,6 +1096,7 @@ class Connection(Connectable):
schema_translate_map=self.schema_for_object
if not self.schema_for_object.is_default
else None,
+ linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
)
ret = self._execute_context(
diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py
index 58fe91c7e..5198c8cd6 100644
--- a/lib/sqlalchemy/engine/create.py
+++ b/lib/sqlalchemy/engine/create.py
@@ -13,6 +13,7 @@ from .. import event
from .. import exc
from .. import pool as poollib
from .. import util
+from ..sql import compiler
@util.deprecated_params(
@@ -142,6 +143,16 @@ def create_engine(url, **kwargs):
:param empty_in_strategy: No longer used; SQLAlchemy now uses
"empty set" behavior for IN in all cases.
+ :param enable_from_linting: defaults to True. Will emit a warning
+ if a given SELECT statement is found to have un-linked FROM elements
+ which would cause a cartesian product.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`change_4737`
+
:param encoding: Defaults to ``utf-8``. This is the string
encoding used by SQLAlchemy for string encode/decode
operations which occur within SQLAlchemy, **outside of
@@ -446,6 +457,11 @@ def create_engine(url, **kwargs):
dialect_args["dbapi"] = dbapi
+ dialect_args.setdefault("compiler_linting", compiler.NO_LINTING)
+ enable_from_linting = kwargs.pop("enable_from_linting", True)
+ if enable_from_linting:
+ dialect_args["compiler_linting"] ^= compiler.COLLECT_CARTESIAN_PRODUCTS
+
for plugin in plugins:
plugin.handle_dialect_kwargs(dialect_cls, dialect_args)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 1c995f05f..378890444 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -31,7 +31,6 @@ from ..sql import expression
from ..sql import schema
from ..sql.elements import quoted_name
-
AUTOCOMMIT_REGEXP = re.compile(
r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE
)
@@ -214,6 +213,9 @@ class DefaultDialect(interfaces.Dialect):
supports_native_boolean=None,
max_identifier_length=None,
label_length=None,
+ # int() is because the @deprecated_params decorator cannot accommodate
+ # the direct reference to the "NO_LINTING" object
+ compiler_linting=int(compiler.NO_LINTING),
**kwargs
):
@@ -249,7 +251,7 @@ class DefaultDialect(interfaces.Dialect):
self._user_defined_max_identifier_length
)
self.label_length = label_length
-
+ self.compiler_linting = compiler_linting
if self.description_encoding == "use_encoding":
self._description_decoder = (
processors.to_unicode_processor_factory
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index 6554faaa0..488717041 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -5,6 +5,10 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+from .compiler import COLLECT_CARTESIAN_PRODUCTS # noqa
+from .compiler import FROM_LINTING # noqa
+from .compiler import NO_LINTING # noqa
+from .compiler import WARN_LINTING # noqa
from .expression import Alias # noqa
from .expression import alias # noqa
from .expression import all_ # noqa
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 8499484f3..ed463ebe3 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -41,7 +41,6 @@ from .base import NO_ARG
from .. import exc
from .. import util
-
RESERVED_WORDS = set(
[
"all",
@@ -270,6 +269,89 @@ ExpandedState = collections.namedtuple(
)
+NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0)
+
+COLLECT_CARTESIAN_PRODUCTS = util.symbol(
+ "COLLECT_CARTESIAN_PRODUCTS",
+ "Collect data on FROMs and cartesian products and gather "
+ "into 'self.from_linter'",
+ canonical=1,
+)
+
+WARN_LINTING = util.symbol(
+ "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2
+)
+
+FROM_LINTING = util.symbol(
+ "FROM_LINTING",
+ "Warn for cartesian products; "
+ "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING",
+ canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING,
+)
+
+
+class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
+ def lint(self, start=None):
+ froms = self.froms
+ if not froms:
+ return None, None
+
+ edges = set(self.edges)
+ the_rest = set(froms)
+
+ if start is not None:
+ start_with = start
+ the_rest.remove(start_with)
+ else:
+ start_with = the_rest.pop()
+
+ stack = collections.deque([start_with])
+
+ while stack and the_rest:
+ node = stack.popleft()
+ the_rest.discard(node)
+
+ # comparison of nodes in edges here is based on hash equality, as
+ # there are "annotated" elements that match the non-annotated ones.
+ # to remove the need for in-python hash() calls, use native
+ # containment routines (e.g. "node in edge", "edge.index(node)")
+ to_remove = {edge for edge in edges if node in edge}
+
+ # appendleft the node in each edge that is not
+ # the one that matched.
+ stack.extendleft(edge[not edge.index(node)] for edge in to_remove)
+ edges.difference_update(to_remove)
+
+ # FROMS left over? boom
+ if the_rest:
+ return the_rest, start_with
+ else:
+ return None, None
+
+ def warn(self):
+ the_rest, start_with = self.lint()
+
+ # FROMS left over? boom
+ if the_rest:
+
+ froms = the_rest
+ if froms:
+ template = (
+ "SELECT statement has a cartesian product between "
+ "FROM element(s) {froms} and "
+ 'FROM element "{start}". Apply join condition(s) '
+ "between each element to resolve."
+ )
+ froms_str = ", ".join(
+ '"{elem}"'.format(elem=self.froms[from_])
+ for from_ in froms
+ )
+ message = template.format(
+ froms=froms_str, start=self.froms[start_with]
+ )
+ util.warn(message)
+
+
class Compiled(object):
"""Represent a compiled SQL or DDL expression.
@@ -568,7 +650,13 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
def __init__(
- self, dialect, statement, column_keys=None, inline=False, **kwargs
+ self,
+ dialect,
+ statement,
+ column_keys=None,
+ inline=False,
+ linting=NO_LINTING,
+ **kwargs
):
"""Construct a new :class:`.SQLCompiler` object.
@@ -592,6 +680,8 @@ class SQLCompiler(Compiled):
# execute)
self.inline = inline or getattr(statement, "inline", False)
+ self.linting = linting
+
# a dictionary of bind parameter keys to BindParameter
# instances.
self.binds = {}
@@ -1547,9 +1637,21 @@ class SQLCompiler(Compiled):
return to_update, replacement_expression
def visit_binary(
- self, binary, override_operator=None, eager_grouping=False, **kw
+ self,
+ binary,
+ override_operator=None,
+ eager_grouping=False,
+ from_linter=None,
+ **kw
):
+ if from_linter and operators.is_comparison(binary.operator):
+ from_linter.edges.update(
+ itertools.product(
+ binary.left._from_objects, binary.right._from_objects
+ )
+ )
+
# don't allow "? = ?" to render
if (
self.ansi_bind_rules
@@ -1568,7 +1670,9 @@ class SQLCompiler(Compiled):
except KeyError:
raise exc.UnsupportedCompilationError(self, operator_)
else:
- return self._generate_generic_binary(binary, opstring, **kw)
+ return self._generate_generic_binary(
+ binary, opstring, from_linter=from_linter, **kw
+ )
def visit_function_as_comparison_op_binary(self, element, operator, **kw):
return self.process(element.sql_function, **kw)
@@ -1916,6 +2020,7 @@ class SQLCompiler(Compiled):
ashint=False,
fromhints=None,
visiting_cte=None,
+ from_linter=None,
**kwargs
):
self._init_cte_state()
@@ -2021,6 +2126,9 @@ class SQLCompiler(Compiled):
self.ctes[cte] = text
if asfrom:
+ if from_linter:
+ from_linter.froms[cte] = cte_name
+
if not is_new_cte and embedded_in_current_named_cte:
return self.preparer.format_alias(cte, cte_name)
@@ -2043,6 +2151,7 @@ class SQLCompiler(Compiled):
subquery=False,
lateral=False,
enclosing_alias=None,
+ from_linter=None,
**kwargs
):
if enclosing_alias is not None and enclosing_alias.element is alias:
@@ -2071,6 +2180,9 @@ class SQLCompiler(Compiled):
if ashint:
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
+ if from_linter:
+ from_linter.froms[alias] = alias_name
+
inner = alias.element._compiler_dispatch(
self, asfrom=True, lateral=lateral, **kwargs
)
@@ -2284,6 +2396,7 @@ class SQLCompiler(Compiled):
compound_index=0,
select_wraps_for=None,
lateral=False,
+ from_linter=None,
**kwargs
):
@@ -2373,7 +2486,7 @@ class SQLCompiler(Compiled):
]
text = self._compose_select_body(
- text, select, inner_columns, froms, byfrom, kwargs
+ text, select, inner_columns, froms, byfrom, toplevel, kwargs
)
if select._statement_hints:
@@ -2465,10 +2578,17 @@ class SQLCompiler(Compiled):
return froms
def _compose_select_body(
- self, text, select, inner_columns, froms, byfrom, kwargs
+ self, text, select, inner_columns, froms, byfrom, toplevel, kwargs
):
text += ", ".join(inner_columns)
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+
if froms:
text += " \nFROM "
@@ -2476,7 +2596,11 @@ class SQLCompiler(Compiled):
text += ", ".join(
[
f._compiler_dispatch(
- self, asfrom=True, fromhints=byfrom, **kwargs
+ self,
+ asfrom=True,
+ fromhints=byfrom,
+ from_linter=from_linter,
+ **kwargs
)
for f in froms
]
@@ -2484,7 +2608,12 @@ class SQLCompiler(Compiled):
else:
text += ", ".join(
[
- f._compiler_dispatch(self, asfrom=True, **kwargs)
+ f._compiler_dispatch(
+ self,
+ asfrom=True,
+ from_linter=from_linter,
+ **kwargs
+ )
for f in froms
]
)
@@ -2492,10 +2621,18 @@ class SQLCompiler(Compiled):
text += self.default_from()
if select._whereclause is not None:
- t = select._whereclause._compiler_dispatch(self, **kwargs)
+ t = select._whereclause._compiler_dispatch(
+ self, from_linter=from_linter, **kwargs
+ )
if t:
text += " \nWHERE " + t
+ if (
+ self.linting & COLLECT_CARTESIAN_PRODUCTS
+ and self.linting & WARN_LINTING
+ ):
+ from_linter.warn()
+
if select._group_by_clause.clauses:
text += self.group_by_clause(select, **kwargs)
@@ -2597,8 +2734,12 @@ class SQLCompiler(Compiled):
ashint=False,
fromhints=None,
use_schema=True,
+ from_linter=None,
**kwargs
):
+ if from_linter:
+ from_linter.froms[table] = table.fullname
+
if asfrom or ashint:
effective_schema = self.preparer.schema_for_object(table)
@@ -2618,7 +2759,10 @@ class SQLCompiler(Compiled):
else:
return ""
- def visit_join(self, join, asfrom=False, **kwargs):
+ def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
+ if from_linter:
+ from_linter.edges.add((join.left, join.right))
+
if join.full:
join_type = " FULL OUTER JOIN "
elif join.isouter:
@@ -2626,12 +2770,18 @@ class SQLCompiler(Compiled):
else:
join_type = " JOIN "
return (
- join.left._compiler_dispatch(self, asfrom=True, **kwargs)
+ join.left._compiler_dispatch(
+ self, asfrom=True, from_linter=from_linter, **kwargs
+ )
+ join_type
- + join.right._compiler_dispatch(self, asfrom=True, **kwargs)
+ + join.right._compiler_dispatch(
+ self, asfrom=True, from_linter=from_linter, **kwargs
+ )
+ " ON "
# TODO: likely need asfrom=True here?
- + join.onclause._compiler_dispatch(self, **kwargs)
+ + join.onclause._compiler_dispatch(
+ self, from_linter=from_linter, **kwargs
+ )
)
def _setup_crud_hints(self, stmt, table_text):