diff options
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/crud.py | 26 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/dml.py | 15 | ||||
| -rw-r--r-- | test/orm/test_cycles.py | 14 | ||||
| -rw-r--r-- | test/orm/test_query.py | 24 | ||||
| -rw-r--r-- | test/orm/test_versioning.py | 2 | ||||
| -rw-r--r-- | test/sql/test_compiler.py | 8 | ||||
| -rw-r--r-- | test/sql/test_update.py | 41 | ||||
| -rw-r--r-- | tox.ini | 1 |
10 files changed, 106 insertions, 32 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index d89a93dd3..ea1c08f67 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1220,6 +1220,8 @@ class BulkUpdate(BulkUD): def __init__(self, query, values, update_kwargs): super(BulkUpdate, self).__init__(query) self.values = values + # Accept values as a dictionary or any other iterable of value pairs + self.values = util.OrderedDict(values) self.update_kwargs = update_kwargs @classmethod @@ -1258,7 +1260,7 @@ class BulkUpdate(BulkUD): "Invalid expression type: %r" % key) def _do_exec(self): - values = dict( + values = util.OrderedDict( (self._resolve_string_to_expr(k), v) for k, v in self.values.items() ) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 691195772..768d4f83a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1971,7 +1971,8 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - crud_params = crud._get_crud_params(self, update_stmt, **kw) + crud_params = crud._get_crud_params(self, update_stmt, keep_order=True, + **kw) if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index e6f16b698..614f9413b 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -26,7 +26,7 @@ values present. """) -def _get_crud_params(compiler, stmt, **kw): +def _get_crud_params(compiler, stmt, keep_order=False, **kw): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -64,12 +64,12 @@ def _get_crud_params(compiler, stmt, **kw): # if we have statement parameters - set defaults in the # compiled params if compiler.column_keys is None: - parameters = {} + parameters = util.OrderedDict() else: - parameters = dict((_column_as_key(key), REQUIRED) - for key in compiler.column_keys - if not stmt_parameters or - key not in stmt_parameters) + parameters = util.OrderedDict((_column_as_key(key), REQUIRED) + for key in compiler.column_keys + if not stmt_parameters or + key not in stmt_parameters) # create a list of column assignment clauses as tuples values = [] @@ -97,7 +97,7 @@ def _get_crud_params(compiler, stmt, **kw): _scan_cols( compiler, stmt, parameters, _getattr_col_key, _column_as_key, - _col_bind_name, check_columns, values, kw) + _col_bind_name, check_columns, values, kw, keep_order=keep_order) if parameters and stmt_parameters: check = set(parameters).intersection( @@ -202,7 +202,7 @@ def _scan_insert_from_select_cols( def _scan_cols( compiler, stmt, parameters, _getattr_col_key, - _column_as_key, _col_bind_name, check_columns, values, kw): + _column_as_key, _col_bind_name, check_columns, values, kw, keep_order): need_pks, implicit_returning, \ implicit_return_defaults, postfetch_lastrowid = \ @@ -210,6 +210,16 @@ def _scan_cols( cols = stmt.table.columns + if keep_order: + # Order columns with parameters first, preserving their original order, + # and then the rest of the columns + keys = tuple(parameters.keys()) if parameters else tuple() + table_cols = tuple(cols) + cols = sorted(table_cols, + key=(lambda x: keys.index(_getattr_col_key(x)) + if _getattr_col_key(x) in keys + else len(keys) + table_cols.index(x))) + for c in cols: col_key = _getattr_col_key(c) if col_key in parameters and col_key not in check_columns: diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 6756f1554..983fed2b5 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -15,6 +15,7 @@ from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \ from .selectable import _interpret_as_from, _interpret_as_select, HasPrefixes from .. import util from .. import exc +from sqlalchemy.sql import schema class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): @@ -30,16 +31,26 @@ class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): _prefixes = () def _process_colparams(self, parameters): + def is_value_pair_dict(params): + # Check if params is a value list/tuple representing a dictionary + return ( + isinstance(params, (list, tuple)) and + all(isinstance(p, (list, tuple)) and len(p) == 2 and + isinstance(p[0], schema.Column) for p in params)) + def process_single(p): if isinstance(p, (list, tuple)): - return dict( + if is_value_pair_dict(p): + return util.OrderedDict(p) + return util.OrderedDict( (c.key, pval) for c, pval in zip(self.table.c, p) ) else: return p - if (isinstance(parameters, (list, tuple)) and parameters and + if (not is_value_pair_dict(parameters) and + isinstance(parameters, (list, tuple)) and parameters and isinstance(parameters[0], (list, tuple, dict))): if not self._supports_multi_parameters: diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index c95b8d152..9230c7247 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -1181,9 +1181,10 @@ class PostUpdateBatchingTest(fixtures.MappedTest): testing.db, sess.flush, CompiledSQL( - "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, " - "c3_id=:c3_id WHERE parent.id = :parent_id", - lambda ctx: {'c2_id': c23.id, 'parent_id': p1.id, 'c1_id': c12.id, 'c3_id': c31.id} + "UPDATE parent SET c2_id=:c2_id, c1_id=:c1_id, c3_id=:c3_id " + "WHERE parent.id = :parent_id", + lambda ctx: {'c2_id': c23.id, 'parent_id': p1.id, + 'c1_id': c12.id, 'c3_id': c31.id} ) ) @@ -1193,8 +1194,9 @@ class PostUpdateBatchingTest(fixtures.MappedTest): testing.db, sess.flush, CompiledSQL( - "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, " - "c3_id=:c3_id WHERE parent.id = :parent_id", - lambda ctx: {'c2_id': None, 'parent_id': p1.id, 'c1_id': None, 'c3_id': None} + "UPDATE parent SET c2_id=:c2_id, c1_id=:c1_id, c3_id=:c3_id " + "WHERE parent.id = :parent_id", + lambda ctx: {'c2_id': None, 'parent_id': p1.id, + 'c1_id': None, 'c3_id': None} ) ) diff --git a/test/orm/test_query.py b/test/orm/test_query.py index b0501739f..0a57b30ef 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -3810,7 +3810,7 @@ class SessionBindTest(QueryTest): def _assert_bind_args(self, session): get_bind = mock.Mock(side_effect=session.get_bind) with mock.patch.object(session, "get_bind", get_bind): - yield + yield get_bind for call_ in get_bind.mock_calls: is_(call_[1][0], inspect(self.classes.User)) is_not_(call_[2]['clause'], None) @@ -3846,6 +3846,28 @@ class SessionBindTest(QueryTest): session.query(User).filter(User.id == 15).update( {"name": "foob"}, synchronize_session=False) + def test_bulk_update_ordered_dict(self): + User = self.classes.User + session = Session() + + # Do update using ordered dict and check that parameters order is + # preserved + with self._assert_bind_args(session) as mock_args: + session.query(User).filter(User.id == 15).update( + util.OrderedDict((('name', 'foob'), ('id', 123)))) + cols = [c.name for c + in mock_args.mock_calls[0][2]['clause'].parameters.keys()] + assert ['name', 'id'] == cols + + # Now invert the order and use a list instead of an ordered dict and + # check that order is also preserved + with self._assert_bind_args(session) as mock_args: + session.query(User).filter(User.id == 15).update( + [('id', 123), ('name', 'foob')]) + cols = [c.name for c + in mock_args.mock_calls[0][2]['clause'].parameters.keys()] + assert ['id', 'name'] == cols + def test_bulk_delete_no_sync(self): User = self.classes.User session = Session() diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index d46799c5a..108e0bc5a 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -930,7 +930,7 @@ class ServerVersioningTest(fixtures.MappedTest): # "default" - on a "returning" backend, the statement # includes "RETURNING" CompiledSQL( - "UPDATE version_table SET version_id=2, value=:value " + "UPDATE version_table SET value=:value, version_id=2 " "WHERE version_table.id = :version_table_id AND " "version_table.version_id = :version_table_version_id", lambda ctx: [ diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index c957b2f8a..899af86a9 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -2764,7 +2764,7 @@ class CRUDTest(fixtures.TestBase, AssertsCompiledSQL): u.values( x=3 + bindparam('x')), - "UPDATE foo SET x=(:param_1 + :x), y=:y WHERE foo.x = :x", + "UPDATE foo SET y=:y, x=(:param_1 + :x) WHERE foo.x = :x", params={ 'x': 1, 'y': 2}) @@ -2951,9 +2951,9 @@ class InlineDefaultTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile(t.update(inline=True, values={'col3': 'foo'}), - "UPDATE test SET col1=foo(:foo_1), col2=(SELECT " - "coalesce(max(foo.id)) AS coalesce_1 FROM foo), " - "col3=:col3") + "UPDATE test SET col3=:col3, col1=foo(:foo_1), " + "col2=(SELECT coalesce(max(foo.id)) AS coalesce_1 " + "FROM foo)") class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 58c86613b..3dd6c99db 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -4,6 +4,7 @@ from sqlalchemy.dialects import mysql from sqlalchemy.engine import default from sqlalchemy.testing import AssertsCompiledSQL, eq_, fixtures from sqlalchemy.testing.schema import Table, Column +from sqlalchemy import util class _UpdateFromTestBase(object): @@ -114,7 +115,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): table1.c.myid == 12, values={table1.c.name: table1.c.myid}), 'UPDATE mytable ' - 'SET name=mytable.myid, description=:description ' + 'SET description=:description, name=mytable.myid ' 'WHERE mytable.myid = :myid_1', params={'description': 'test'}, checkparams={'description': 'test', 'myid_1': 12}) @@ -127,7 +128,8 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): 'UPDATE mytable ' 'SET myid=:myid, description=:description ' 'WHERE mytable.myid = :myid_1', - params={'myid_1': 12, 'myid': 9, 'description': 'test'}) + params=util.OrderedDict(( + ('myid_1', 12), ('myid', 9), ('description', 'test')))) def test_update_8(self): table1 = self.tables.mytable @@ -153,18 +155,41 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): update(table1, table1.c.myid == 12, values=v1).values(v2), 'UPDATE mytable ' 'SET ' - 'name=(mytable.name || :name_1), ' - 'description=:description ' + 'description=:description, ' + 'name=(mytable.name || :name_1) ' 'WHERE mytable.myid = :myid_1', params={'description': 'test'}) def test_update_11(self): table1 = self.tables.mytable - values = { - table1.c.name: table1.c.name + 'lala', - table1.c.myid: func.do_stuff(table1.c.myid, literal('hoho')) - } + values = util.OrderedDict(( + (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))), + (table1.c.name, table1.c.name + 'lala'))) + self.assert_compile( + update( + table1, + (table1.c.myid == func.hoho(4)) & ( + table1.c.name == literal('foo') + + table1.c.name + + literal('lala')), + values=values), + 'UPDATE mytable ' + 'SET ' + 'myid=do_stuff(mytable.myid, :param_1), ' + 'name=(mytable.name || :name_1) ' + 'WHERE ' + 'mytable.myid = hoho(:hoho_1) AND ' + 'mytable.name = :param_2 || mytable.name || :param_3') + + def test_update_12(self): + table1 = self.tables.mytable + + # Confirm that we can pass values not only as dicts and ordered dicts, + # but as value pairs + values = ( + (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))), + (table1.c.name, table1.c.name + 'lala')) self.assert_compile( update( table1, @@ -11,6 +11,7 @@ deps=pytest setenv= PYTHONPATH= PYTHONNOUSERSITE=1 + PYTHONHASHSEED = 0 # we need this because our CI has all the DBAPIs and such # pre-installed in individual site-packages directories. |
