summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2013-07-27 17:05:01 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2013-07-27 17:05:01 -0400
commitd64b09b15c703d7e100931818976822256f6beda (patch)
treefb46d2b6aa486f47f7e949ba14f145c00aa0f7c5
parenta6adc0a53ebfeeea23ec63655629c5b6831a77d0 (diff)
downloadsqlalchemy-d64b09b15c703d7e100931818976822256f6beda.tar.gz
- fix issue in join rewriting whereby we need to ensure the .key and .name
are transferred correctly for when .key is present; tests have been enhanced to test this condition for render, result map construction, statement execution. [ticket:2790]
-rw-r--r--lib/sqlalchemy/sql/compiler.py4
-rw-r--r--test/sql/test_join_rewriting.py87
2 files changed, 86 insertions, 5 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index a5f545de9..72bebca3d 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1162,7 +1162,8 @@ class SQLCompiler(engine.Compiled):
use_labels=True).alias()
for c in selectable.c:
- c._label = c._key_label = c.name
+ c._key_label = c.key
+ c._label = c.name
translate_dict = dict(
zip(right.element.c, selectable.c)
)
@@ -1199,6 +1200,7 @@ class SQLCompiler(engine.Compiled):
def _transform_result_map_for_nested_joins(self, select, transformed_select):
inner_col = dict((c._key_label, c) for
c in transformed_select.inner_columns)
+
d = dict(
(inner_col[c._key_label], c)
for c in select.inner_columns
diff --git a/test/sql/test_join_rewriting.py b/test/sql/test_join_rewriting.py
index 5a9bdd1d3..3c670199c 100644
--- a/test/sql/test_join_rewriting.py
+++ b/test/sql/test_join_rewriting.py
@@ -1,10 +1,11 @@
from sqlalchemy import Table, Column, Integer, MetaData, ForeignKey, select
-from sqlalchemy.testing import fixtures, AssertsCompiledSQL
+from sqlalchemy.testing import fixtures, AssertsCompiledSQL, eq_
from sqlalchemy import util
from sqlalchemy.engine import default
from sqlalchemy import testing
+
m = MetaData()
a = Table('a', m,
@@ -30,6 +31,15 @@ e = Table('e', m,
Column('id', Integer, primary_key=True)
)
+b_key = Table('b_key', m,
+ Column('id', Integer, primary_key=True, key='bid'),
+ )
+
+a_to_b_key = Table('a_to_b_key', m,
+ Column('aid', Integer, ForeignKey('a.id')),
+ Column('bid', Integer, ForeignKey('b_key.bid')),
+ )
+
class _JoinRewriteTestBase(AssertsCompiledSQL):
def _test(self, s, assert_):
self.assert_compile(
@@ -38,7 +48,10 @@ class _JoinRewriteTestBase(AssertsCompiledSQL):
)
compiled = s.compile(dialect=self.__dialect__)
- for key, col in zip([c.key for c in s.c], s.inner_columns):
+
+ # column name should be in result map, as we never render
+ # .key in SQL
+ for key, col in zip([c.name for c in s.c], s.inner_columns):
key = key % compiled.anon_map
assert col in compiled.result_map[key][1]
@@ -60,6 +73,27 @@ class _JoinRewriteTestBase(AssertsCompiledSQL):
self._test(s, self._a_bc)
+ def test_a_bkeyassoc(self):
+ j1 = b_key.join(a_to_b_key)
+ j2 = a.join(j1)
+
+ s = select([a, b_key.c.bid], use_labels=True).\
+ select_from(j2)
+
+ self._test(s, self._a_bkeyassoc)
+
+ def test_a_bkeyassoc_aliased(self):
+ bkey_alias = b_key.alias()
+ a_to_b_key_alias = a_to_b_key.alias()
+
+ j1 = bkey_alias.join(a_to_b_key_alias)
+ j2 = a.join(j1)
+
+ s = select([a, bkey_alias.c.bid], use_labels=True).\
+ select_from(j2)
+
+ self._test(s, self._a_bkeyassoc_aliased)
+
def test_a__b_dc(self):
j1 = c.join(d)
j2 = b.join(j1)
@@ -94,6 +128,7 @@ class _JoinRewriteTestBase(AssertsCompiledSQL):
self._a_bc_comma_a1_selbc
)
+
class JoinRewriteTest(_JoinRewriteTestBase, fixtures.TestBase):
"""test rendering of each join with right-nested rewritten as
aliased SELECT statements.."""
@@ -149,6 +184,24 @@ class JoinRewriteTest(_JoinRewriteTestBase, fixtures.TestBase):
"ON a_1.id = anon_2.b_a_id ORDER BY anon_2.b_id"
)
+ _a_bkeyassoc = (
+ "SELECT a.id AS a_id, anon_1.b_key_id AS b_key_id "
+ "FROM a JOIN "
+ "(SELECT b_key.id AS b_key_id, a_to_b_key.aid AS a_to_b_key_aid, "
+ "a_to_b_key.bid AS a_to_b_key_bid FROM b_key "
+ "JOIN a_to_b_key ON b_key.id = a_to_b_key.bid) AS anon_1 "
+ "ON a.id = anon_1.a_to_b_key_aid"
+ )
+
+ _a_bkeyassoc_aliased = (
+ "SELECT a.id AS a_id, anon_1.b_key_1_id AS b_key_1_id "
+ "FROM a JOIN (SELECT b_key_1.id AS b_key_1_id, "
+ "a_to_b_key_1.aid AS a_to_b_key_1_aid, "
+ "a_to_b_key_1.bid AS a_to_b_key_1_bid FROM b_key AS b_key_1 "
+ "JOIN a_to_b_key AS a_to_b_key_1 ON b_key_1.id = a_to_b_key_1.bid) AS "
+ "anon_1 ON a.id = anon_1.a_to_b_key_1_aid"
+ )
+
class JoinPlainTest(_JoinRewriteTestBase, fixtures.TestBase):
"""test rendering of each join with normal nesting."""
@util.classproperty
@@ -194,6 +247,19 @@ class JoinPlainTest(_JoinRewriteTestBase, fixtures.TestBase):
"ON a_1.id = anon_1.b_a_id ORDER BY anon_1.b_id"
)
+ _a_bkeyassoc = (
+ "SELECT a.id AS a_id, b_key.id AS b_key_id "
+ "FROM a JOIN "
+ "(b_key JOIN a_to_b_key ON b_key.id = a_to_b_key.bid) "
+ "ON a.id = a_to_b_key.aid"
+ )
+
+ _a_bkeyassoc_aliased = (
+ "SELECT a.id AS a_id, b_key_1.id AS b_key_1_id FROM a "
+ "JOIN (b_key AS b_key_1 JOIN a_to_b_key AS a_to_b_key_1 "
+ "ON b_key_1.id = a_to_b_key_1.bid) ON a.id = a_to_b_key_1.aid"
+ )
+
class JoinNoUseLabelsTest(_JoinRewriteTestBase, fixtures.TestBase):
@util.classproperty
def __dialect__(cls):
@@ -245,10 +311,21 @@ class JoinNoUseLabelsTest(_JoinRewriteTestBase, fixtures.TestBase):
"ON a_1.id = anon_1.b_a_id ORDER BY anon_1.b_id"
)
+ _a_bkeyassoc = (
+ "SELECT a.id, b_key.id FROM a JOIN (b_key JOIN a_to_b_key "
+ "ON b_key.id = a_to_b_key.bid) ON a.id = a_to_b_key.aid"
+ )
+
+ _a_bkeyassoc_aliased = (
+ "SELECT a.id, b_key_1.id FROM a JOIN (b_key AS b_key_1 "
+ "JOIN a_to_b_key AS a_to_b_key_1 ON b_key_1.id = a_to_b_key_1.bid) "
+ "ON a.id = a_to_b_key_1.aid"
+ )
+
class JoinExecTest(_JoinRewriteTestBase, fixtures.TestBase):
"""invoke the SQL on the current backend to ensure compatibility"""
- _a_bc = _a_bc_comma_a1_selbc = _a__b_dc = None
+ _a_bc = _a_bc_comma_a1_selbc = _a__b_dc = _a_bkeyassoc = _a_bkeyassoc_aliased = None
@classmethod
def setup_class(cls):
@@ -259,7 +336,9 @@ class JoinExecTest(_JoinRewriteTestBase, fixtures.TestBase):
m.drop_all(testing.db)
def _test(self, selectable, assert_):
- testing.db.execute(selectable)
+ result = testing.db.execute(selectable)
+ for col in selectable.inner_columns:
+ assert col in result._metadata._keymap
class DialectFlagTest(fixtures.TestBase, AssertsCompiledSQL):