diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-01-20 21:01:35 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-01-20 21:17:42 -0500 |
commit | b9318c98637bbd5c19267728fcfe941668345325 (patch) | |
tree | 10e97a4ee419b1eeeb8073ac4e516fd0592fc510 | |
parent | f8d45fd5666c6d0285576798ecd4c409909fe810 (diff) | |
download | sqlalchemy-b9318c98637bbd5c19267728fcfe941668345325.tar.gz |
- Fixed the multiple-table "UPDATE..FROM" construct, only usable on
MySQL, to correctly render the SET clause among multiple columns
with the same name across tables. This also changes the name used for
the bound parameter in the SET clause to "<tablename>_<colname>" for
the non-primary table only; as this parameter is typically specified
using the :class:`.Column` object directly this should not have an
impact on applications. The fix takes effect for both
:meth:`.Table.update` as well as :meth:`.Query.update` in the ORM.
[ticket:2912]
-rw-r--r-- | doc/build/changelog/changelog_09.rst | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 110 | ||||
-rw-r--r-- | test/orm/test_update_delete.py | 34 | ||||
-rw-r--r-- | test/sql/test_update.py | 156 |
5 files changed, 270 insertions, 51 deletions
diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index 0efffce62..d59f3ec60 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -83,6 +83,19 @@ Pullreq courtesy Derek Harland. .. change:: + :tags: bug, sql, orm + :tickets: 2912 + + Fixed the multiple-table "UPDATE..FROM" construct, only usable on + MySQL, to correctly render the SET clause among multiple columns + with the same name across tables. This also changes the name used for + the bound parameter in the SET clause to "<tablename>_<colname>" for + the non-primary table only; as this parameter is typically specified + using the :class:`.Column` object directly this should not have an + impact on applications. The fix takes effect for both + :meth:`.Table.update` as well as :meth:`.Query.update` in the ORM. + + .. change:: :tags: bug, oracle :tickets: 2911 diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index e507885fa..ed975b8cf 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -895,6 +895,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): and generate inserted_primary_key collection. """ + key_getter = self.compiled._key_getters_for_crud_column[2] + if self.executemany: if len(self.compiled.prefetch): scalar_defaults = {} @@ -918,7 +920,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): else: val = self.get_update_default(c) if val is not None: - param[c.key] = val + param[key_getter(c)] = val del self.current_parameters else: self.current_parameters = compiled_parameters = \ @@ -931,12 +933,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): val = self.get_update_default(c) if val is not None: - compiled_parameters[c.key] = val + compiled_parameters[key_getter(c)] = val del self.current_parameters if self.isinsert: self.inserted_primary_key = [ - self.compiled_parameters[0].get(c.key, None) + self.compiled_parameters[0].get(key_getter(c), None) for c in self.compiled.\ statement.table.primary_key ] diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5c5bfad55..4448f7c7b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -28,6 +28,7 @@ from . import schema, sqltypes, operators, functions, \ from .. import util, exc import decimal import itertools +import operator RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -1771,7 +1772,7 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - colparams = self._get_colparams(update_stmt, extra_froms, **kw) + colparams = self._get_colparams(update_stmt, **kw) if update_stmt._hints: dialect_hints = dict([ @@ -1840,7 +1841,40 @@ class SQLCompiler(Compiled): bindparam._is_crud = True return bindparam._compiler_dispatch(self) - def _get_colparams(self, stmt, extra_tables=None, **kw): + @util.memoized_property + def _key_getters_for_crud_column(self): + if self.isupdate and self.statement._extra_froms: + # when extra tables are present, refer to the columns + # in those extra tables as table-qualified, including in + # dictionaries and when rendering bind param names. + # the "main" table of the statement remains unqualified, + # allowing the most compatibility with a non-multi-table + # statement. + _et = set(self.statement._extra_froms) + def _column_as_key(key): + str_key = elements._column_as_key(key) + if hasattr(key, 'table') and key.table in _et: + return (key.table.name, str_key) + else: + return str_key + def _getattr_col_key(col): + if col.table in _et: + return (col.table.name, col.key) + else: + return col.key + def _col_bind_name(col): + if col.table in _et: + return "%s_%s" % (col.table.name, col.key) + else: + return col.key + + else: + _column_as_key = elements._column_as_key + _getattr_col_key = _col_bind_name = operator.attrgetter("key") + + return _column_as_key, _getattr_col_key, _col_bind_name + + def _get_colparams(self, stmt, **kw): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -1869,12 +1903,18 @@ class SQLCompiler(Compiled): else: stmt_parameters = stmt.parameters + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + _column_as_key, _getattr_col_key, _col_bind_name = \ + self._key_getters_for_crud_column + # if we have statement parameters - set defaults in the # compiled params if self.column_keys is None: parameters = {} else: - parameters = dict((elements._column_as_key(key), REQUIRED) + parameters = dict((_column_as_key(key), REQUIRED) for key in self.column_keys if not stmt_parameters or key not in stmt_parameters) @@ -1884,7 +1924,7 @@ class SQLCompiler(Compiled): if stmt_parameters is not None: for k, v in stmt_parameters.items(): - colkey = elements._column_as_key(k) + colkey = _column_as_key(k) if colkey is not None: parameters.setdefault(colkey, v) else: @@ -1892,7 +1932,9 @@ class SQLCompiler(Compiled): # add it to values() in an "as-is" state, # coercing right side to bound param if elements._is_literal(v): - v = self.process(elements.BindParameter(None, v, type_=k.type), **kw) + v = self.process( + elements.BindParameter(None, v, type_=k.type), + **kw) else: v = self.process(v.self_group(), **kw) @@ -1922,24 +1964,25 @@ class SQLCompiler(Compiled): postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid check_columns = {} + # special logic that only occurs for multi-table UPDATE # statements - if extra_tables and stmt_parameters: + if self.isupdate and stmt._extra_froms and stmt_parameters: normalized_params = dict( (elements._clause_element_as_expr(c), param) for c, param in stmt_parameters.items() ) - assert self.isupdate affected_tables = set() - for t in extra_tables: + for t in stmt._extra_froms: for c in t.c: if c in normalized_params: affected_tables.add(t) - check_columns[c.key] = c + check_columns[_getattr_col_key(c)] = c value = normalized_params[c] if elements._is_literal(value): value = self._create_crud_bind_param( - c, value, required=value is REQUIRED) + c, value, required=value is REQUIRED, + name=_col_bind_name(c)) else: self.postfetch.append(c) value = self.process(value.self_group(), **kw) @@ -1954,12 +1997,18 @@ class SQLCompiler(Compiled): elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, self.process(c.onupdate.arg.self_group(), **kw)) + (c, self.process( + c.onupdate.arg.self_group(), + **kw) + ) ) self.postfetch.append(c) else: values.append( - (c, self._create_crud_bind_param(c, None)) + (c, self._create_crud_bind_param( + c, None, name=_col_bind_name(c) + ) + ) ) self.prefetch.append(c) elif c.server_onupdate is not None: @@ -1968,7 +2017,7 @@ class SQLCompiler(Compiled): if self.isinsert and stmt.select_names: # for an insert from select, we can only use names that # are given, so only select for those names. - cols = (stmt.table.c[elements._column_as_key(name)] + cols = (stmt.table.c[_column_as_key(name)] for name in stmt.select_names) else: # iterate through all table columns to maintain @@ -1976,14 +2025,15 @@ class SQLCompiler(Compiled): cols = stmt.table.columns for c in cols: - if c.key in parameters and c.key not in check_columns: - value = parameters.pop(c.key) + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + value = parameters.pop(col_key) if elements._is_literal(value): value = self._create_crud_bind_param( c, value, required=value is REQUIRED, - name=c.key + name=_col_bind_name(c) if not stmt._has_multi_parameters - else "%s_0" % c.key + else "%s_0" % _col_bind_name(c) ) else: if isinstance(value, elements.BindParameter) and \ @@ -2119,12 +2169,12 @@ class SQLCompiler(Compiled): if parameters and stmt_parameters: check = set(parameters).intersection( - elements._column_as_key(k) for k in stmt.parameters + _column_as_key(k) for k in stmt.parameters ).difference(check_columns) if check: raise exc.CompileError( "Unconsumed column names: %s" % - (", ".join(check)) + (", ".join("%s" % c for c in check)) ) if stmt._has_multi_parameters: @@ -2133,17 +2183,17 @@ class SQLCompiler(Compiled): values.extend( [ - ( - c, - self._create_crud_bind_param( - c, row[c.key], - name="%s_%d" % (c.key, i + 1) - ) - if c.key in row else param - ) - for (c, param) in values_0 - ] - for i, row in enumerate(stmt.parameters[1:]) + ( + c, + self._create_crud_bind_param( + c, row[c.key], + name="%s_%d" % (c.key, i + 1) + ) + if c.key in row else param + ) + for (c, param) in values_0 + ] + for i, row in enumerate(stmt.parameters[1:]) ) return values diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 6915ac8a2..ac94fde2f 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -545,12 +545,14 @@ class UpdateDeleteFromTest(fixtures.MappedTest): def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), + Column('samename', String(10)), ) Table('documents', metadata, Column('id', Integer, primary_key=True), Column('user_id', None, ForeignKey('users.id')), Column('title', String(32)), - Column('flag', Boolean) + Column('flag', Boolean), + Column('samename', String(10)), ) @classmethod @@ -659,6 +661,34 @@ class UpdateDeleteFromTest(fixtures.MappedTest): ]) ) + @testing.only_on('mysql', 'Multi table update') + def test_update_from_multitable_same_names(self): + Document = self.classes.Document + User = self.classes.User + + s = Session() + + s.query(Document).\ + filter(User.id == Document.user_id).\ + filter(User.id == 2).update({ + Document.samename: 'd_samename', + User.samename: 'u_samename' + } + ) + eq_( + s.query(User.id, Document.samename, User.samename). + filter(User.id == Document.user_id). + order_by(User.id).all(), + [ + (1, None, None), + (1, None, None), + (2, 'd_samename', 'u_samename'), + (2, 'd_samename', 'u_samename'), + (3, None, None), + (3, None, None), + ] + ) + class ExpressionUpdateTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): @@ -786,3 +816,5 @@ class InheritTest(fixtures.DeclarativeMappedTest): set(s.query(Person.name, Engineer.engineer_name)), set([('e1', 'e1', ), ('e22', 'e55')]) ) + + diff --git a/test/sql/test_update.py b/test/sql/test_update.py index a8510f374..10306372b 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -192,22 +192,6 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): 'UPDATE A B C D mytable SET myid=%s, name=%s, description=%s', dialect=mysql.dialect()) - def test_alias(self): - table1 = self.tables.mytable - talias1 = table1.alias('t1') - - self.assert_compile(update(talias1, talias1.c.myid == 7), - 'UPDATE mytable AS t1 ' - 'SET name=:name ' - 'WHERE t1.myid = :myid_1', - params={table1.c.name: 'fred'}) - - self.assert_compile(update(talias1, table1.c.myid == 7), - 'UPDATE mytable AS t1 ' - 'SET name=:name ' - 'FROM mytable ' - 'WHERE mytable.myid = :myid_1', - params={table1.c.name: 'fred'}) def test_update_to_expression(self): """test update from an expression. @@ -268,6 +252,64 @@ class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, run_create_tables = run_inserts = run_deletes = None + def test_alias_one(self): + table1 = self.tables.mytable + talias1 = table1.alias('t1') + + # this case is nonsensical. the UPDATE is entirely + # against the alias, but we name the table-bound column + # in values. The behavior here isn't really defined + self.assert_compile( + update(talias1, talias1.c.myid == 7). + values({table1.c.name: "fred"}), + 'UPDATE mytable AS t1 ' + 'SET name=:name ' + 'WHERE t1.myid = :myid_1') + + def test_alias_two(self): + table1 = self.tables.mytable + talias1 = table1.alias('t1') + + # Here, compared to + # test_alias_one(), here we actually have UPDATE..FROM, + # which is causing the "table1.c.name" param to be handled + # as an "extra table", hence we see the full table name rendered. + self.assert_compile( + update(talias1, table1.c.myid == 7). + values({table1.c.name: 'fred'}), + 'UPDATE mytable AS t1 ' + 'SET name=:mytable_name ' + 'FROM mytable ' + 'WHERE mytable.myid = :myid_1', + checkparams={'mytable_name': 'fred', 'myid_1': 7}, + ) + + def test_alias_two_mysql(self): + table1 = self.tables.mytable + talias1 = table1.alias('t1') + + self.assert_compile( + update(talias1, table1.c.myid == 7). + values({table1.c.name: 'fred'}), + "UPDATE mytable AS t1, mytable SET mytable.name=%s " + "WHERE mytable.myid = %s", + checkparams={'mytable_name': 'fred', 'myid_1': 7}, + dialect='mysql') + + def test_update_from_multitable_same_name_mysql(self): + users, addresses = self.tables.users, self.tables.addresses + + self.assert_compile( + users.update(). + values(name='newname').\ + values({addresses.c.name: "new address"}).\ + where(users.c.id == addresses.c.user_id), + "UPDATE users, addresses SET addresses.name=%s, " + "users.name=%s WHERE users.id = addresses.user_id", + checkparams={u'addresses_name': 'new address', 'name': 'newname'}, + dialect='mysql' + ) + def test_render_table(self): users, addresses = self.tables.users, self.tables.addresses @@ -455,6 +497,36 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (10, 'chuck')] self._assert_users(users, expected) + @testing.only_on('mysql', 'Multi table update') + def test_exec_multitable_same_name(self): + users, addresses = self.tables.users, self.tables.addresses + + values = { + addresses.c.name: 'ad_ed2', + users.c.name: 'ed2' + } + + testing.db.execute( + addresses.update(). + values(values). + where(users.c.id == addresses.c.user_id). + where(users.c.name == 'ed')) + + expected = [ + (1, 7, 'x', 'jack@bean.com'), + (2, 8, 'ad_ed2', 'ed@wood.com'), + (3, 8, 'ad_ed2', 'ed@bettyboop.com'), + (4, 8, 'ad_ed2', 'ed@lala.com'), + (5, 9, 'x', 'fred@fred.com')] + self._assert_addresses(addresses, expected) + + expected = [ + (7, 'jack'), + (8, 'ed2'), + (9, 'fred'), + (10, 'chuck')] + self._assert_users(users, expected) + def _assert_addresses(self, addresses, expected): stmt = addresses.select().order_by(addresses.c.id) eq_(testing.db.execute(stmt).fetchall(), expected) @@ -478,7 +550,16 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', None, ForeignKey('users.id')), - Column('email_address', String(50), nullable=False)) + Column('email_address', String(50), nullable=False), + ) + + Table('foobar', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('user_id', None, ForeignKey('users.id')), + Column('data', String(30)), + Column('some_update', String(30), onupdate='im the other update') + ) @classmethod def fixtures(cls): @@ -494,6 +575,12 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, (3, 8, 'ed@bettyboop.com'), (4, 9, 'fred@fred.com') ), + foobar=( + ('id', 'user_id', 'data'), + (2, 8, 'd1'), + (3, 8, 'd2'), + (4, 9, 'd3') + ) ) @testing.only_on('mysql', 'Multi table update') @@ -525,6 +612,37 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, self._assert_users(users, expected) @testing.only_on('mysql', 'Multi table update') + def test_defaults_second_table_same_name(self): + users, foobar = self.tables.users, self.tables.foobar + + values = { + foobar.c.data: foobar.c.data + 'a', + users.c.name: 'ed2' + } + + ret = testing.db.execute( + users.update(). + values(values). + where(users.c.id == foobar.c.user_id). + where(users.c.name == 'ed')) + + eq_( + set(ret.prefetch_cols()), + set([users.c.some_update, foobar.c.some_update]) + ) + + expected = [ + (2, 8, 'd1a', 'im the other update'), + (3, 8, 'd2a', 'im the other update'), + (4, 9, 'd3', None)] + self._assert_foobar(foobar, expected) + + expected = [ + (8, 'ed2', 'im the update'), + (9, 'fred', 'value')] + self._assert_users(users, expected) + + @testing.only_on('mysql', 'Multi table update') def test_no_defaults_second_table(self): users, addresses = self.tables.users, self.tables.addresses @@ -548,6 +666,10 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, (9, 'fred', 'value')] self._assert_users(users, expected) + def _assert_foobar(self, foobar, expected): + stmt = foobar.select().order_by(foobar.c.id) + eq_(testing.db.execute(stmt).fetchall(), expected) + def _assert_addresses(self, addresses, expected): stmt = addresses.select().order_by(addresses.c.id) eq_(testing.db.execute(stmt).fetchall(), expected) |