summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2011-11-22 18:05:05 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2011-11-22 18:05:05 -0500
commit4de3b28abce67a09dfde1cffd8a244b6542ae8c1 (patch)
treea5ec8ef126ce6dc82eff4e8d9e7393ac80dcbb77
parent90b6ca30e430a06ed1d1696f3881ae72c6014ecd (diff)
downloadsqlalchemy-4de3b28abce67a09dfde1cffd8a244b6542ae8c1.tar.gz
fixes to actually get tests to pass
-rw-r--r--lib/sqlalchemy/sql/compiler.py29
-rw-r--r--lib/sqlalchemy/sql/expression.py14
-rw-r--r--test/aaa_profiling/test_compiler.py4
-rw-r--r--test/sql/test_update.py12
4 files changed, 37 insertions, 22 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 24c3687e9..4b1b9bd5d 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1025,11 +1025,7 @@ class SQLCompiler(engine.Compiled):
self.isupdate = True
- if update_stmt._whereclause is not None:
- extra_froms = set(update_stmt._whereclause._from_objects).\
- difference([update_stmt.table])
- else:
- extra_froms = None
+ extra_froms = update_stmt._extra_froms
colparams = self._get_colparams(update_stmt, extra_froms)
@@ -1038,20 +1034,17 @@ class SQLCompiler(engine.Compiled):
update_stmt.table,
extra_froms, **kw)
+ text += ' SET '
if extra_froms and self.render_table_with_column_in_update_from:
- text += ' SET ' + \
- ', '.join(
+ text += ', '.join(
self.visit_column(c[0]) +
- '=' + c[1]
- for c in colparams
- )
+ '=' + c[1] for c in colparams
+ )
else:
- text += ' SET ' + \
- ', '.join(
+ text += ', '.join(
self.preparer.quote(c[0].name, c[0].quote) +
- '=' + c[1]
- for c in colparams
- )
+ '=' + c[1] for c in colparams
+ )
if update_stmt._returning:
self.returning = update_stmt._returning
@@ -1144,6 +1137,8 @@ class SQLCompiler(engine.Compiled):
postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
check_columns = {}
+ # special logic that only occurs for multi-table UPDATE
+ # statements
if extra_tables and stmt.parameters:
for t in extra_tables:
for c in t.c:
@@ -1186,7 +1181,7 @@ class SQLCompiler(engine.Compiled):
(
implicit_returning or
not postfetch_lastrowid or
- c is not t._autoincrement_column
+ c is not stmt.table._autoincrement_column
):
if implicit_returning:
@@ -1213,7 +1208,7 @@ class SQLCompiler(engine.Compiled):
self.returning.append(c)
else:
if c.default is not None or \
- c is t._autoincrement_column and (
+ c is stmt.table._autoincrement_column and (
self.dialect.supports_sequences or
self.dialect.preexecute_autoincrement_sequences
):
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 6520be202..6eb4367b3 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -5292,6 +5292,20 @@ class Update(ValuesBase):
else:
self._whereclause = _literal_as_text(whereclause)
+ @property
+ def _extra_froms(self):
+ # TODO: this could be made memoized
+ # if the memoization is reset on each generative call.
+ froms = []
+ seen = set([self.table])
+
+ if self._whereclause is not None:
+ for item in _from_objects(self._whereclause):
+ if not seen.intersection(item._cloned_set):
+ froms.append(item)
+ seen.update(item._cloned_set)
+
+ return froms
class Delete(UpdateBase):
"""Represent a DELETE construct.
diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py
index f949ce6ea..a7ce7a70b 100644
--- a/test/aaa_profiling/test_compiler.py
+++ b/test/aaa_profiling/test_compiler.py
@@ -39,11 +39,11 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults):
def test_insert(self):
t1.insert().compile(dialect=self.dialect)
- @profiling.function_call_count(versions={'2.6':53, '2.7':53})
+ @profiling.function_call_count(versions={'2.6':56, '2.7':56})
def test_update(self):
t1.update().compile(dialect=self.dialect)
- @profiling.function_call_count(versions={'2.6':110, '2.7':110, '3':115})
+ @profiling.function_call_count(versions={'2.6':117, '2.7':117, '3':118})
def test_update_whereclause(self):
t1.update().where(t1.c.c2==12).compile(dialect=self.dialect)
diff --git a/test/sql/test_update.py b/test/sql/test_update.py
index 87fd6ffd5..2ea3d92a4 100644
--- a/test/sql/test_update.py
+++ b/test/sql/test_update.py
@@ -7,9 +7,7 @@ from test.lib import *
from test.lib.schema import Table, Column
from sqlalchemy.dialects import mysql
-class UpdateFromTest(fixtures.TablesTest, AssertsCompiledSQL):
- __dialect__ = 'default'
-
+class _UpdateFromTestBase(object):
@classmethod
def define_tables(cls, metadata):
Table('users', metadata,
@@ -65,6 +63,12 @@ class UpdateFromTest(fixtures.TablesTest, AssertsCompiledSQL):
),
)
+
+class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
+ __dialect__ = 'default'
+
+ run_create_tables = run_inserts = run_deletes = None
+
def test_render_table(self):
users, addresses = self.tables.users, self.tables.addresses
self.assert_compile(
@@ -134,6 +138,8 @@ class UpdateFromTest(fixtures.TablesTest, AssertsCompiledSQL):
u'id_1': 7, 'name': 'newname'}
)
+class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
+
@testing.requires.update_from
def test_exec_two_table(self):
users, addresses = self.tables.users, self.tables.addresses