diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2021-07-13 15:15:33 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2021-07-13 15:15:33 +0000 |
| commit | 3f4ee382a71558d4dbc1d37a2bedcecdce3d5461 (patch) | |
| tree | c6ca6a07e7b1a7c5b9682de66277c04b0880d28c | |
| parent | b64ecb03a5411dd5f32e40ac564bec9a886d3672 (diff) | |
| parent | a0953bb7095dde805de8c13699b122767ed001b9 (diff) | |
| download | sqlalchemy-3f4ee382a71558d4dbc1d37a2bedcecdce3d5461.tar.gz | |
Merge "Adjust CTE recursive col list to accommodate dupe col names"
| -rw-r--r-- | doc/build/changelog/unreleased_14/6710.rst | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 41 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 5 | ||||
| -rw-r--r-- | test/orm/test_core_compilation.py | 11 | ||||
| -rw-r--r-- | test/sql/test_compiler.py | 3 | ||||
| -rw-r--r-- | test/sql/test_cte.py | 148 | ||||
| -rw-r--r-- | test/sql/test_selectable.py | 6 |
7 files changed, 209 insertions, 13 deletions
diff --git a/doc/build/changelog/unreleased_14/6710.rst b/doc/build/changelog/unreleased_14/6710.rst new file mode 100644 index 000000000..32784e889 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6710.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, sql + :tickets: 6710 + + Fixed issue in CTE constructs where a recursive CTE that referred to a + SELECT that has duplicate column names, which are typically deduplicated + using labeling logic in 1.4, would fail to refer to the deduplicated label + name correctly within the WITH clause. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 581dad4fb..b9f55b746 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1311,6 +1311,9 @@ class SQLCompiler(Compiled): def visit_grouping(self, grouping, asfrom=False, **kwargs): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" + def visit_select_statement_grouping(self, grouping, **kwargs): + return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" + def visit_label_reference( self, element, within_columns_clause=False, **kwargs ): @@ -2562,17 +2565,29 @@ class SQLCompiler(Compiled): col_source = cte.element.selects[0] else: assert False, "cte should only be against SelectBase" + + # TODO: can we get at the .columns_plus_names collection + # that is already (or will be?) generated for the SELECT + # rather than calling twice? recur_cols = [ - c - for c in util.unique_list( - col_source._all_selected_columns - ) - if c is not None + # TODO: proxy_name is not technically safe, + # see test_cte-> + # test_with_recursive_no_name_currently_buggy. 
not + # clear what should be done with such a case + fallback_label_name or proxy_name + for ( + _, + proxy_name, + fallback_label_name, + c, + repeated, + ) in (col_source._generate_columns_plus_names(True)) + if not repeated ] text += "(%s)" % ( ", ".join( - self.preparer.format_column( + self.preparer.format_label_name( ident, anon_map=self.anon_map ) for ident in recur_cols @@ -5124,6 +5139,20 @@ class IdentifierPreparer(object): return self.quote(name) + def format_label_name( + self, + name, + anon_map=None, + ): + """Prepare a quoted column name.""" + + if anon_map is not None and isinstance( + name, elements._truncated_label + ): + name = name.apply_map(anon_map) + + return self.quote(name) + def format_column( self, column, diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 30a613089..b6cf7f55e 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3139,7 +3139,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase): """ - __visit_name__ = "grouping" + __visit_name__ = "select_statement_grouping" _traverse_internals = [("element", InternalTraversal.dp_clauseelement)] _is_select_container = True @@ -3173,6 +3173,9 @@ class SelectStatementGrouping(GroupedElement, SelectBase): def self_group(self, against=None): return self + def _generate_columns_plus_names(self, anon_for_dupe_key): + return self.element._generate_columns_plus_names(anon_for_dupe_key) + def _generate_fromclause_column_proxies(self, subquery): self.element._generate_fromclause_column_proxies(subquery) diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 5f25b56e8..e730d9097 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -627,8 +627,10 @@ class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select(u_alias), - "SELECT anon_1.id FROM ((SELECT users.name, users.id FROM users " - "WHERE users.id = :id_1 UNION 
SELECT users.name, users.id " + "SELECT anon_1.id FROM ((SELECT users.name AS name, " + "users.id AS id FROM users " + "WHERE users.id = :id_1 UNION SELECT users.name AS name, " + "users.id AS id " "FROM users WHERE users.id = :id_2) " "UNION SELECT users.name AS name, users.id AS id " "FROM users WHERE users.id = :id_3) AS anon_1", @@ -656,8 +658,9 @@ class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL): self.assert_compile( select(u_alias).options(undefer(u_alias.name)), "SELECT anon_1.name, anon_1.id FROM " - "((SELECT users.name, users.id FROM users " - "WHERE users.id = :id_1 UNION SELECT users.name, users.id " + "((SELECT users.name AS name, users.id AS id FROM users " + "WHERE users.id = :id_1 UNION SELECT users.name AS name, " + "users.id AS id " "FROM users WHERE users.id = :id_2) " "UNION SELECT users.name AS name, users.id AS id " "FROM users WHERE users.id = :id_3) AS anon_1", diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index f2c1e004d..40faab486 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -892,6 +892,9 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "WITH RECURSIVE (colnames)" part. This test shows that this isn't correct when keys are present. 
+ See also test_cte -> + test_wrecur_ovlp_lbls_plus_dupes_separate_keys_use_labels + """ m = MetaData() foo = Table( diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index e8a8a3150..f1d27aa8f 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,4 +1,9 @@ +from sqlalchemy import Column from sqlalchemy import delete +from sqlalchemy import Integer +from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL +from sqlalchemy import MetaData +from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import update @@ -495,6 +500,149 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): s.compile, ) + def test_with_recursive_no_name_currently_buggy(self): + s1 = select(1) + c1 = s1.cte(name="cte1", recursive=True) + + # this is nonsensical at the moment + self.assert_compile( + select(c1), + 'WITH RECURSIVE cte1("1") AS (SELECT 1) SELECT cte1.1 FROM cte1', + ) + + # however, so is subquery, which is worse as it isn't even trying + # to quote "1" as a label + self.assert_compile( + select(s1.subquery()), "SELECT anon_1.1 FROM (SELECT 1) AS anon_1" + ) + + def test_wrecur_dupe_col_names(self): + """test #6710""" + + manager = table("manager", column("id")) + employee = table("employee", column("id"), column("manager_id")) + + top_q = select(employee, manager).join_from( + employee, manager, employee.c.manager_id == manager.c.id + ) + + top_q = top_q.cte("cte", recursive=True) + + bottom_q = ( + select(employee, manager) + .join_from( + employee, manager, employee.c.manager_id == manager.c.id + ) + .join(top_q, top_q.c.id == employee.c.id) + ) + + rec_cte = select(top_q.union_all(bottom_q)) + self.assert_compile( + rec_cte, + "WITH RECURSIVE cte(id, manager_id, id_1) AS " + "(SELECT employee.id AS id, employee.manager_id AS manager_id, " + "manager.id AS id_1 FROM employee JOIN manager " + "ON employee.manager_id = manager.id UNION ALL " + "SELECT employee.id AS id, employee.manager_id AS manager_id, " + 
"manager.id AS id_1 FROM employee JOIN manager ON " + "employee.manager_id = manager.id " + "JOIN cte ON cte.id = employee.id) " + "SELECT cte.id, cte.manager_id, cte.id_1 FROM cte", + ) + + def test_wrecur_dupe_col_names_w_grouping(self): + """test #6710 + + by adding order_by() to the top query, the CTE will have + a compound select with the first element a SelectStatementGrouping + object, which we can test has the correct methods for the compiler + to call upon. + + """ + + manager = table("manager", column("id")) + employee = table("employee", column("id"), column("manager_id")) + + top_q = ( + select(employee, manager) + .join_from( + employee, manager, employee.c.manager_id == manager.c.id + ) + .order_by(employee.c.id) + .cte("cte", recursive=True) + ) + + bottom_q = ( + select(employee, manager) + .join_from( + employee, manager, employee.c.manager_id == manager.c.id + ) + .join(top_q, top_q.c.id == employee.c.id) + ) + + rec_cte = select(top_q.union_all(bottom_q)) + + self.assert_compile( + rec_cte, + "WITH RECURSIVE cte(id, manager_id, id_1) AS " + "((SELECT employee.id AS id, employee.manager_id AS manager_id, " + "manager.id AS id_1 FROM employee JOIN manager " + "ON employee.manager_id = manager.id ORDER BY employee.id) " + "UNION ALL " + "SELECT employee.id AS id, employee.manager_id AS manager_id, " + "manager.id AS id_1 FROM employee JOIN manager ON " + "employee.manager_id = manager.id " + "JOIN cte ON cte.id = employee.id) " + "SELECT cte.id, cte.manager_id, cte.id_1 FROM cte", + ) + + def test_wrecur_ovlp_lbls_plus_dupes_separate_keys_use_labels(self): + """test a condition related to #6710. + + also see test_compiler-> + test_overlapping_labels_plus_dupes_separate_keys_use_labels + + for a non cte form of this test. 
+ + """ + + m = MetaData() + foo = Table( + "foo", + m, + Column("id", Integer), + Column("bar_id", Integer, key="bb"), + ) + foo_bar = Table("foo_bar", m, Column("id", Integer, key="bb")) + + stmt = select( + foo.c.id, + foo.c.bb, + foo_bar.c.bb, + foo.c.bb, + foo.c.id, + foo.c.bb, + foo_bar.c.bb, + foo_bar.c.bb, + ).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + + cte = stmt.cte(recursive=True) + + self.assert_compile( + select(cte), + "WITH RECURSIVE anon_1(foo_id, foo_bar_id, foo_bar_id_1) AS " + "(SELECT foo.id AS foo_id, foo.bar_id AS foo_bar_id, " + "foo_bar.id AS foo_bar_id_1, foo.bar_id AS foo_bar_id__1, " + "foo.id AS foo_id__1, foo.bar_id AS foo_bar_id__1, " + "foo_bar.id AS foo_bar_id__2, foo_bar.id AS foo_bar_id__2 " + "FROM foo, foo_bar) " + "SELECT anon_1.foo_id, anon_1.foo_bar_id, anon_1.foo_bar_id_1, " + "anon_1.foo_bar_id AS foo_bar_id_2, anon_1.foo_id AS foo_id_1, " + "anon_1.foo_bar_id AS foo_bar_id_3, " + "anon_1.foo_bar_id_1 AS foo_bar_id_1_1, " + "anon_1.foo_bar_id_1 AS foo_bar_id_1_2 FROM anon_1", + ) + def test_union(self): orders = table("orders", column("region"), column("amount")) diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index be894d239..cfdf4ad02 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -998,9 +998,11 @@ class SelectableTest( self.assert_compile( stmt, "SELECT anon_1.col1, anon_1.col2, anon_1.col1_1 FROM " - "((SELECT table1.col1, table1.col2, table2.col1 AS col1_1 " + "((SELECT table1.col1 AS col1, table1.col2 AS col2, table2.col1 " + "AS col1_1 " "FROM table1, table2 LIMIT :param_1) UNION " - "(SELECT table2.col1, table2.col2, table2.col3 FROM table2 " + "(SELECT table2.col1 AS col1, table2.col2 AS col2, " + "table2.col3 AS col3 FROM table2 " "LIMIT :param_2)) AS anon_1", ) |
