summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/persistence.py4
-rw-r--r--lib/sqlalchemy/sql/compiler.py3
-rw-r--r--lib/sqlalchemy/sql/crud.py26
-rw-r--r--lib/sqlalchemy/sql/dml.py15
-rw-r--r--test/orm/test_cycles.py14
-rw-r--r--test/orm/test_query.py24
-rw-r--r--test/orm/test_versioning.py2
-rw-r--r--test/sql/test_compiler.py8
-rw-r--r--test/sql/test_update.py41
-rw-r--r--tox.ini1
10 files changed, 106 insertions, 32 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index d89a93dd3..ea1c08f67 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -1220,6 +1220,8 @@ class BulkUpdate(BulkUD):
def __init__(self, query, values, update_kwargs):
super(BulkUpdate, self).__init__(query)
self.values = values
+ # Accept values as a dictionary or any other iterable of value pairs
+ self.values = util.OrderedDict(values)
self.update_kwargs = update_kwargs
@classmethod
@@ -1258,7 +1260,7 @@ class BulkUpdate(BulkUD):
"Invalid expression type: %r" % key)
def _do_exec(self):
- values = dict(
+ values = util.OrderedDict(
(self._resolve_string_to_expr(k), v)
for k, v in self.values.items()
)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 691195772..768d4f83a 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1971,7 +1971,8 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- crud_params = crud._get_crud_params(self, update_stmt, **kw)
+ crud_params = crud._get_crud_params(self, update_stmt, keep_order=True,
+ **kw)
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index e6f16b698..614f9413b 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -26,7 +26,7 @@ values present.
""")
-def _get_crud_params(compiler, stmt, **kw):
+def _get_crud_params(compiler, stmt, keep_order=False, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -64,12 +64,12 @@ def _get_crud_params(compiler, stmt, **kw):
# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
- parameters = {}
+ parameters = util.OrderedDict()
else:
- parameters = dict((_column_as_key(key), REQUIRED)
- for key in compiler.column_keys
- if not stmt_parameters or
- key not in stmt_parameters)
+ parameters = util.OrderedDict((_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if not stmt_parameters or
+ key not in stmt_parameters)
# create a list of column assignment clauses as tuples
values = []
@@ -97,7 +97,7 @@ def _get_crud_params(compiler, stmt, **kw):
_scan_cols(
compiler, stmt, parameters,
_getattr_col_key, _column_as_key,
- _col_bind_name, check_columns, values, kw)
+ _col_bind_name, check_columns, values, kw, keep_order=keep_order)
if parameters and stmt_parameters:
check = set(parameters).intersection(
@@ -202,7 +202,7 @@ def _scan_insert_from_select_cols(
def _scan_cols(
compiler, stmt, parameters, _getattr_col_key,
- _column_as_key, _col_bind_name, check_columns, values, kw):
+ _column_as_key, _col_bind_name, check_columns, values, kw, keep_order):
need_pks, implicit_returning, \
implicit_return_defaults, postfetch_lastrowid = \
@@ -210,6 +210,16 @@ def _scan_cols(
cols = stmt.table.columns
+ if keep_order:
+ # Order columns with parameters first, preserving their original order,
+ # and then the rest of the columns
+ keys = tuple(parameters.keys()) if parameters else tuple()
+ table_cols = tuple(cols)
+ cols = sorted(table_cols,
+ key=(lambda x: keys.index(_getattr_col_key(x))
+ if _getattr_col_key(x) in keys
+ else len(keys) + table_cols.index(x)))
+
for c in cols:
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 6756f1554..983fed2b5 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -15,6 +15,7 @@ from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \
from .selectable import _interpret_as_from, _interpret_as_select, HasPrefixes
from .. import util
from .. import exc
+from sqlalchemy.sql import schema
class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement):
@@ -30,16 +31,26 @@ class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement):
_prefixes = ()
def _process_colparams(self, parameters):
+ def is_value_pair_dict(params):
+ # Check if params is a value list/tuple representing a dictionary
+ return (
+ isinstance(params, (list, tuple)) and
+ all(isinstance(p, (list, tuple)) and len(p) == 2 and
+ isinstance(p[0], schema.Column) for p in params))
+
def process_single(p):
if isinstance(p, (list, tuple)):
- return dict(
+ if is_value_pair_dict(p):
+ return util.OrderedDict(p)
+ return util.OrderedDict(
(c.key, pval)
for c, pval in zip(self.table.c, p)
)
else:
return p
- if (isinstance(parameters, (list, tuple)) and parameters and
+ if (not is_value_pair_dict(parameters) and
+ isinstance(parameters, (list, tuple)) and parameters and
isinstance(parameters[0], (list, tuple, dict))):
if not self._supports_multi_parameters:
diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py
index c95b8d152..9230c7247 100644
--- a/test/orm/test_cycles.py
+++ b/test/orm/test_cycles.py
@@ -1181,9 +1181,10 @@ class PostUpdateBatchingTest(fixtures.MappedTest):
testing.db,
sess.flush,
CompiledSQL(
- "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, "
- "c3_id=:c3_id WHERE parent.id = :parent_id",
- lambda ctx: {'c2_id': c23.id, 'parent_id': p1.id, 'c1_id': c12.id, 'c3_id': c31.id}
+ "UPDATE parent SET c2_id=:c2_id, c1_id=:c1_id, c3_id=:c3_id "
+ "WHERE parent.id = :parent_id",
+ lambda ctx: {'c2_id': c23.id, 'parent_id': p1.id,
+ 'c1_id': c12.id, 'c3_id': c31.id}
)
)
@@ -1193,8 +1194,9 @@ class PostUpdateBatchingTest(fixtures.MappedTest):
testing.db,
sess.flush,
CompiledSQL(
- "UPDATE parent SET c1_id=:c1_id, c2_id=:c2_id, "
- "c3_id=:c3_id WHERE parent.id = :parent_id",
- lambda ctx: {'c2_id': None, 'parent_id': p1.id, 'c1_id': None, 'c3_id': None}
+ "UPDATE parent SET c2_id=:c2_id, c1_id=:c1_id, c3_id=:c3_id "
+ "WHERE parent.id = :parent_id",
+ lambda ctx: {'c2_id': None, 'parent_id': p1.id,
+ 'c1_id': None, 'c3_id': None}
)
)
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index b0501739f..0a57b30ef 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -3810,7 +3810,7 @@ class SessionBindTest(QueryTest):
def _assert_bind_args(self, session):
get_bind = mock.Mock(side_effect=session.get_bind)
with mock.patch.object(session, "get_bind", get_bind):
- yield
+ yield get_bind
for call_ in get_bind.mock_calls:
is_(call_[1][0], inspect(self.classes.User))
is_not_(call_[2]['clause'], None)
@@ -3846,6 +3846,28 @@ class SessionBindTest(QueryTest):
session.query(User).filter(User.id == 15).update(
{"name": "foob"}, synchronize_session=False)
+ def test_bulk_update_ordered_dict(self):
+ User = self.classes.User
+ session = Session()
+
+ # Do update using ordered dict and check that parameters order is
+ # preserved
+ with self._assert_bind_args(session) as mock_args:
+ session.query(User).filter(User.id == 15).update(
+ util.OrderedDict((('name', 'foob'), ('id', 123))))
+ cols = [c.name for c
+ in mock_args.mock_calls[0][2]['clause'].parameters.keys()]
+ assert ['name', 'id'] == cols
+
+ # Now invert the order and use a list instead of an ordered dict and
+ # check that order is also preserved
+ with self._assert_bind_args(session) as mock_args:
+ session.query(User).filter(User.id == 15).update(
+ [('id', 123), ('name', 'foob')])
+ cols = [c.name for c
+ in mock_args.mock_calls[0][2]['clause'].parameters.keys()]
+ assert ['id', 'name'] == cols
+
def test_bulk_delete_no_sync(self):
User = self.classes.User
session = Session()
diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py
index d46799c5a..108e0bc5a 100644
--- a/test/orm/test_versioning.py
+++ b/test/orm/test_versioning.py
@@ -930,7 +930,7 @@ class ServerVersioningTest(fixtures.MappedTest):
# "default" - on a "returning" backend, the statement
# includes "RETURNING"
CompiledSQL(
- "UPDATE version_table SET version_id=2, value=:value "
+ "UPDATE version_table SET value=:value, version_id=2 "
"WHERE version_table.id = :version_table_id AND "
"version_table.version_id = :version_table_version_id",
lambda ctx: [
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index c957b2f8a..899af86a9 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -2764,7 +2764,7 @@ class CRUDTest(fixtures.TestBase, AssertsCompiledSQL):
u.values(
x=3 +
bindparam('x')),
- "UPDATE foo SET x=(:param_1 + :x), y=:y WHERE foo.x = :x",
+ "UPDATE foo SET y=:y, x=(:param_1 + :x) WHERE foo.x = :x",
params={
'x': 1,
'y': 2})
@@ -2951,9 +2951,9 @@ class InlineDefaultTest(fixtures.TestBase, AssertsCompiledSQL):
)
self.assert_compile(t.update(inline=True, values={'col3': 'foo'}),
- "UPDATE test SET col1=foo(:foo_1), col2=(SELECT "
- "coalesce(max(foo.id)) AS coalesce_1 FROM foo), "
- "col3=:col3")
+ "UPDATE test SET col3=:col3, col1=foo(:foo_1), "
+ "col2=(SELECT coalesce(max(foo.id)) AS coalesce_1 "
+ "FROM foo)")
class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
diff --git a/test/sql/test_update.py b/test/sql/test_update.py
index 58c86613b..3dd6c99db 100644
--- a/test/sql/test_update.py
+++ b/test/sql/test_update.py
@@ -4,6 +4,7 @@ from sqlalchemy.dialects import mysql
from sqlalchemy.engine import default
from sqlalchemy.testing import AssertsCompiledSQL, eq_, fixtures
from sqlalchemy.testing.schema import Table, Column
+from sqlalchemy import util
class _UpdateFromTestBase(object):
@@ -114,7 +115,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
table1.c.myid == 12,
values={table1.c.name: table1.c.myid}),
'UPDATE mytable '
- 'SET name=mytable.myid, description=:description '
+ 'SET description=:description, name=mytable.myid '
'WHERE mytable.myid = :myid_1',
params={'description': 'test'},
checkparams={'description': 'test', 'myid_1': 12})
@@ -127,7 +128,8 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
'UPDATE mytable '
'SET myid=:myid, description=:description '
'WHERE mytable.myid = :myid_1',
- params={'myid_1': 12, 'myid': 9, 'description': 'test'})
+ params=util.OrderedDict((
+ ('myid_1', 12), ('myid', 9), ('description', 'test'))))
def test_update_8(self):
table1 = self.tables.mytable
@@ -153,18 +155,41 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
update(table1, table1.c.myid == 12, values=v1).values(v2),
'UPDATE mytable '
'SET '
- 'name=(mytable.name || :name_1), '
- 'description=:description '
+ 'description=:description, '
+ 'name=(mytable.name || :name_1) '
'WHERE mytable.myid = :myid_1',
params={'description': 'test'})
def test_update_11(self):
table1 = self.tables.mytable
- values = {
- table1.c.name: table1.c.name + 'lala',
- table1.c.myid: func.do_stuff(table1.c.myid, literal('hoho'))
- }
+ values = util.OrderedDict((
+ (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))),
+ (table1.c.name, table1.c.name + 'lala')))
+ self.assert_compile(
+ update(
+ table1,
+ (table1.c.myid == func.hoho(4)) & (
+ table1.c.name == literal('foo') +
+ table1.c.name +
+ literal('lala')),
+ values=values),
+ 'UPDATE mytable '
+ 'SET '
+ 'myid=do_stuff(mytable.myid, :param_1), '
+ 'name=(mytable.name || :name_1) '
+ 'WHERE '
+ 'mytable.myid = hoho(:hoho_1) AND '
+ 'mytable.name = :param_2 || mytable.name || :param_3')
+
+ def test_update_12(self):
+ table1 = self.tables.mytable
+
+ # Confirm that we can pass values not only as dicts and ordered dicts,
+ # but as value pairs
+ values = (
+ (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))),
+ (table1.c.name, table1.c.name + 'lala'))
self.assert_compile(
update(
table1,
diff --git a/tox.ini b/tox.ini
index 2bb589207..627988166 100644
--- a/tox.ini
+++ b/tox.ini
@@ -11,6 +11,7 @@ deps=pytest
setenv=
PYTHONPATH=
PYTHONNOUSERSITE=1
+ PYTHONHASHSEED = 0
# we need this because our CI has all the DBAPIs and such
# pre-installed in individual site-packages directories.