diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-07-10 11:00:49 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-07-10 11:00:49 -0400 |
| commit | 0e41673ed4e8551b892c058ffc6a607cf7aba71c (patch) | |
| tree | 6914091d7cbb58331d147242e6efa83bd1345424 /test/sql/test_cte.py | |
| parent | b297b40fca923a03e3c34094e5298d6524944c39 (diff) | |
| download | sqlalchemy-0e41673ed4e8551b892c058ffc6a607cf7aba71c.tar.gz | |
- [bug] Fixed more un-intuitivenesses in CTEs
which prevented referring to a CTE in a union
of itself without it being aliased.
CTEs now render uniquely
on name, rendering the outermost CTE of a given
name only - all other references are rendered
just as the name. This even includes other
CTE/SELECTs that refer to different versions
of the same CTE object, such as a SELECT
or a UNION ALL of that SELECT. We are
somewhat loosening the usual link between object
identity and lexical identity in this case.
A true name conflict between two unrelated
CTEs now raises an error.
Diffstat (limited to 'test/sql/test_cte.py')
| -rw-r--r-- | test/sql/test_cte.py | 179 |
1 files changed, 159 insertions, 20 deletions
diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 36f992a86..59b347ccd 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,15 +1,16 @@ from test.lib import fixtures -from test.lib.testing import AssertsCompiledSQL +from test.lib.testing import AssertsCompiledSQL, assert_raises_message from sqlalchemy.sql import table, column, select, func, literal from sqlalchemy.dialects import mssql from sqlalchemy.engine import default +from sqlalchemy.exc import CompileError class CTETest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' def test_nonrecursive(self): - orders = table('orders', + orders = table('orders', column('region'), column('amount'), column('product'), @@ -17,22 +18,22 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): ) regional_sales = select([ - orders.c.region, + orders.c.region, func.sum(orders.c.amount).label('total_sales') ]).group_by(orders.c.region).cte("regional_sales") top_regions = select([regional_sales.c.region]).\ where( - regional_sales.c.total_sales > + regional_sales.c.total_sales > select([ func.sum(regional_sales.c.total_sales)/10 ]) ).cte("top_regions") s = select([ - orders.c.region, - orders.c.product, - func.sum(orders.c.quantity).label("product_units"), + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), func.sum(orders.c.amount).label("product_sales") ]).where(orders.c.region.in_( select([top_regions.c.region]) @@ -59,15 +60,15 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): ) def test_recursive(self): - parts = table('parts', + parts = table('parts', column('part'), column('sub_part'), column('quantity'), ) included_parts = select([ - parts.c.sub_part, - parts.c.part, + parts.c.sub_part, + parts.c.part, parts.c.quantity]).\ where(parts.c.part=='our part').\ cte(recursive=True) @@ -76,19 +77,19 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): parts_alias = parts.alias() included_parts = included_parts.union( select([ - parts_alias.c.part, - parts_alias.c.sub_part, + parts_alias.c.part, + parts_alias.c.sub_part, parts_alias.c.quantity]).\ where(parts_alias.c.part==incl_alias.c.sub_part) ) s = select([ - included_parts.c.sub_part, + included_parts.c.sub_part, func.sum(included_parts.c.quantity).label('total_quantity')]).\ select_from(included_parts.join( parts,included_parts.c.part==parts.c.part)).\ group_by(included_parts.c.sub_part) - self.assert_compile(s, + self.assert_compile(s, "WITH RECURSIVE anon_1(sub_part, part, quantity) " "AS (SELECT parts.sub_part AS sub_part, parts.part " "AS part, parts.quantity AS quantity FROM parts " @@ -104,7 +105,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): # quick check that the "WITH RECURSIVE" varies per # dialect - self.assert_compile(s, + self.assert_compile(s, "WITH anon_1(sub_part, part, quantity) " "AS (SELECT parts.sub_part AS sub_part, parts.part " "AS part, parts.quantity AS quantity FROM parts " @@ -119,8 +120,146 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): dialect=mssql.dialect() ) + def test_recursive_union_no_alias_one(self): + s1 = select([literal(0).label("x")]) + cte = s1.cte(name="cte", recursive=True) + cte = cte.union_all( + select([cte.c.x + 1]).where(cte.c.x < 10) + ) + s2 = select([cte]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2) " + "SELECT cte.x FROM cte" + ) + + + def test_recursive_union_no_alias_two(self): + """ + + pg's example: + + WITH RECURSIVE t(n) AS ( + VALUES (1) + UNION ALL + SELECT n+1 FROM t WHERE n < 100 + ) + SELECT sum(n) FROM t; + + """ + + # I know, this is the PG VALUES keyword, + # we're cheating here. also yes we need the SELECT, + # sorry PG. + t = select([func.values(1).label("n")]).cte("t", recursive=True) + t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100)) + s = select([func.sum(t.c.n)]) + self.assert_compile(s, + "WITH RECURSIVE t(n) AS " + "(SELECT values(:values_1) AS n " + "UNION ALL SELECT t.n + :n_1 AS anon_1 " + "FROM t " + "WHERE t.n < :n_2) " + "SELECT sum(t.n) AS sum_1 FROM t" + ) + + def test_recursive_union_no_alias_three(self): + # like test one, but let's refer to the CTE + # in a sibling CTE. + + s1 = select([literal(0).label("x")]) + cte = s1.cte(name="cte", recursive=True) + + # can't do it here... + #bar = select([cte]).cte('bar') + cte = cte.union_all( + select([cte.c.x + 1]).where(cte.c.x < 10) + ) + bar = select([cte]).cte('bar') + + s2 = select([cte, bar]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT cte.x, bar.x FROM cte, bar" + ) + + + def test_recursive_union_no_alias_four(self): + # like test one and three, but let's refer + # previous version of "cte". here we test + # how the compiler resolves multiple instances + # of "cte". + + s1 = select([literal(0).label("x")]) + cte = s1.cte(name="cte", recursive=True) + + bar = select([cte]).cte('bar') + cte = cte.union_all( + select([cte.c.x + 1]).where(cte.c.x < 10) + ) + + # outer cte rendered first, then bar, which + # includes "inner" cte + s2 = select([cte, bar]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT cte.x, bar.x FROM cte, bar" + ) + + # bar rendered, only includes "inner" cte, + # "outer" cte isn't present + s2 = select([bar]) + self.assert_compile(s2, + "WITH RECURSIVE cte(x) AS " + "(SELECT :param_1 AS x), " + "bar AS (SELECT cte.x AS x FROM cte) " + "SELECT bar.x FROM bar" + ) + + # bar rendered, but then the "outer" + # cte is rendered. + s2 = select([bar, cte]) + self.assert_compile(s2, + "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), " + "cte(x) AS " + "(SELECT :param_1 AS x UNION ALL " + "SELECT cte.x + :x_1 AS anon_1 " + "FROM cte WHERE cte.x < :x_2) " + + "SELECT bar.x, cte.x FROM bar, cte" + ) + + def test_conflicting_names(self): + """test a flat out name conflict.""" + + s1 = select([1]) + c1= s1.cte(name='cte1', recursive=True) + s2 = select([1]) + c2 = s2.cte(name='cte1', recursive=True) + + s = select([c1, c2]) + assert_raises_message( + CompileError, + "Multiple, unrelated CTEs found " + "with the same name: 'cte1'", + s.compile + ) + + + + def test_union(self): - orders = table('orders', + orders = table('orders', column('region'), column('amount'), ) @@ -135,7 +274,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): regional_sales.c.amount > 500 ) - self.assert_compile(s, + self.assert_compile(s, "WITH regional_sales AS " "(SELECT orders.region AS region, " "orders.amount AS amount FROM orders) " @@ -149,7 +288,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): regional_sales.c.amount < 300 ) ) - self.assert_compile(s, + self.assert_compile(s, "WITH regional_sales AS " "(SELECT orders.region AS region, " "orders.amount AS amount FROM orders) " @@ -160,7 +299,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "regional_sales.amount < :amount_2") def test_reserved_quote(self): - orders = table('orders', + orders = table('orders', column('order'), ) s = select([orders.c.order]).cte("regional_sales", recursive=True) @@ -174,7 +313,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): ) def test_positional_binds(self): - orders = table('orders', + orders = table('orders', column('order'), ) s = select([orders.c.order, literal("x")]).cte("regional_sales") |
