diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-01-20 21:01:35 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2014-01-20 21:17:42 -0500 |
commit | b9318c98637bbd5c19267728fcfe941668345325 (patch) | |
tree | 10e97a4ee419b1eeeb8073ac4e516fd0592fc510 /lib/sqlalchemy/sql/compiler.py | |
parent | f8d45fd5666c6d0285576798ecd4c409909fe810 (diff) | |
download | sqlalchemy-b9318c98637bbd5c19267728fcfe941668345325.tar.gz |
- Fixed the multiple-table "UPDATE..FROM" construct, only usable on
MySQL, to correctly render the SET clause among multiple columns
with the same name across tables. This also changes the name used for
the bound parameter in the SET clause to "<tablename>_<colname>" for
the non-primary table only; as this parameter is typically specified
using the :class:`.Column` object directly this should not have an
impact on applications. The fix takes effect for both
:meth:`.Table.update` as well as :meth:`.Query.update` in the ORM.
[ticket:2912]
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 110 |
1 files changed, 80 insertions, 30 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5c5bfad55..4448f7c7b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -28,6 +28,7 @@ from . import schema, sqltypes, operators, functions, \ from .. import util, exc import decimal import itertools +import operator RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -1771,7 +1772,7 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - colparams = self._get_colparams(update_stmt, extra_froms, **kw) + colparams = self._get_colparams(update_stmt, **kw) if update_stmt._hints: dialect_hints = dict([ @@ -1840,7 +1841,40 @@ class SQLCompiler(Compiled): bindparam._is_crud = True return bindparam._compiler_dispatch(self) - def _get_colparams(self, stmt, extra_tables=None, **kw): + @util.memoized_property + def _key_getters_for_crud_column(self): + if self.isupdate and self.statement._extra_froms: + # when extra tables are present, refer to the columns + # in those extra tables as table-qualified, including in + # dictionaries and when rendering bind param names. + # the "main" table of the statement remains unqualified, + # allowing the most compatibility with a non-multi-table + # statement. + _et = set(self.statement._extra_froms) + def _column_as_key(key): + str_key = elements._column_as_key(key) + if hasattr(key, 'table') and key.table in _et: + return (key.table.name, str_key) + else: + return str_key + def _getattr_col_key(col): + if col.table in _et: + return (col.table.name, col.key) + else: + return col.key + def _col_bind_name(col): + if col.table in _et: + return "%s_%s" % (col.table.name, col.key) + else: + return col.key + + else: + _column_as_key = elements._column_as_key + _getattr_col_key = _col_bind_name = operator.attrgetter("key") + + return _column_as_key, _getattr_col_key, _col_bind_name + + def _get_colparams(self, stmt, **kw): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -1869,12 +1903,18 @@ class SQLCompiler(Compiled): else: stmt_parameters = stmt.parameters + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + _column_as_key, _getattr_col_key, _col_bind_name = \ + self._key_getters_for_crud_column + # if we have statement parameters - set defaults in the # compiled params if self.column_keys is None: parameters = {} else: - parameters = dict((elements._column_as_key(key), REQUIRED) + parameters = dict((_column_as_key(key), REQUIRED) for key in self.column_keys if not stmt_parameters or key not in stmt_parameters) @@ -1884,7 +1924,7 @@ class SQLCompiler(Compiled): if stmt_parameters is not None: for k, v in stmt_parameters.items(): - colkey = elements._column_as_key(k) + colkey = _column_as_key(k) if colkey is not None: parameters.setdefault(colkey, v) else: @@ -1892,7 +1932,9 @@ class SQLCompiler(Compiled): # add it to values() in an "as-is" state, # coercing right side to bound param if elements._is_literal(v): - v = self.process(elements.BindParameter(None, v, type_=k.type), **kw) + v = self.process( + elements.BindParameter(None, v, type_=k.type), + **kw) else: v = self.process(v.self_group(), **kw) @@ -1922,24 +1964,25 @@ class SQLCompiler(Compiled): postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid check_columns = {} + # special logic that only occurs for multi-table UPDATE # statements - if extra_tables and stmt_parameters: + if self.isupdate and stmt._extra_froms and stmt_parameters: normalized_params = dict( (elements._clause_element_as_expr(c), param) for c, param in stmt_parameters.items() ) - assert self.isupdate affected_tables = set() - for t in extra_tables: + for t in stmt._extra_froms: for c in t.c: if c in normalized_params: affected_tables.add(t) - check_columns[c.key] = c + check_columns[_getattr_col_key(c)] = c value = normalized_params[c] if elements._is_literal(value): value = self._create_crud_bind_param( - c, value, required=value is REQUIRED) + c, value, required=value is REQUIRED, + name=_col_bind_name(c)) else: self.postfetch.append(c) value = self.process(value.self_group(), **kw) @@ -1954,12 +1997,18 @@ class SQLCompiler(Compiled): elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, self.process(c.onupdate.arg.self_group(), **kw)) + (c, self.process( + c.onupdate.arg.self_group(), + **kw) + ) ) self.postfetch.append(c) else: values.append( - (c, self._create_crud_bind_param(c, None)) + (c, self._create_crud_bind_param( + c, None, name=_col_bind_name(c) + ) + ) ) self.prefetch.append(c) elif c.server_onupdate is not None: @@ -1968,7 +2017,7 @@ class SQLCompiler(Compiled): if self.isinsert and stmt.select_names: # for an insert from select, we can only use names that # are given, so only select for those names. - cols = (stmt.table.c[elements._column_as_key(name)] + cols = (stmt.table.c[_column_as_key(name)] for name in stmt.select_names) else: # iterate through all table columns to maintain @@ -1976,14 +2025,15 @@ class SQLCompiler(Compiled): cols = stmt.table.columns for c in cols: - if c.key in parameters and c.key not in check_columns: - value = parameters.pop(c.key) + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + value = parameters.pop(col_key) if elements._is_literal(value): value = self._create_crud_bind_param( c, value, required=value is REQUIRED, - name=c.key + name=_col_bind_name(c) if not stmt._has_multi_parameters - else "%s_0" % c.key + else "%s_0" % _col_bind_name(c) ) else: if isinstance(value, elements.BindParameter) and \ @@ -2119,12 +2169,12 @@ class SQLCompiler(Compiled): if parameters and stmt_parameters: check = set(parameters).intersection( - elements._column_as_key(k) for k in stmt.parameters + _column_as_key(k) for k in stmt.parameters ).difference(check_columns) if check: raise exc.CompileError( "Unconsumed column names: %s" % - (", ".join(check)) + (", ".join("%s" % c for c in check)) ) if stmt._has_multi_parameters: @@ -2133,17 +2183,17 @@ class SQLCompiler(Compiled): values.extend( [ - ( - c, - self._create_crud_bind_param( - c, row[c.key], - name="%s_%d" % (c.key, i + 1) - ) - if c.key in row else param - ) - for (c, param) in values_0 - ] - for i, row in enumerate(stmt.parameters[1:]) + ( + c, + self._create_crud_bind_param( + c, row[c.key], + name="%s_%d" % (c.key, i + 1) + ) + if c.key in row else param + ) + for (c, param) in values_0 + ] + for i, row in enumerate(stmt.parameters[1:]) ) return values |