summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2016-09-20 11:33:16 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2016-09-20 11:33:16 -0400
commitf8ecdf47f0975b8b4e357fde2008d9aae8c50239 (patch)
tree3efea46680fc5ca957387542f21b5e52eaf1a737
parent881369b949cff44e0017fdc28d9722ef3c26171a (diff)
downloadsqlalchemy-f8ecdf47f0975b8b4e357fde2008d9aae8c50239.tar.gz
Allow SQL expressions to be set on PK columns
Removes an unnecessary transfer of modified PK column value to the params dictionary, so that if the modified PK column is already present in value_params, this remains in effect. Also propagate a new flag through to _emit_update_statements() that will trip "return_defaults()" across the board if a PK col w/ SQL expression change is present, and pull this PK value in _postfetch as well assuming we're an UPDATE. Change-Id: I9ae87f964df9ba8faea8e25e96b8327f968e5d1b Fixes: #3801
-rw-r--r--doc/build/changelog/changelog_11.rst10
-rw-r--r--lib/sqlalchemy/orm/persistence.py28
-rw-r--r--test/orm/test_naturalpks.py37
-rw-r--r--test/orm/test_versioning.py32
4 files changed, 98 insertions, 9 deletions
diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst
index a09703489..6aa5624dd 100644
--- a/doc/build/changelog/changelog_11.rst
+++ b/doc/build/changelog/changelog_11.rst
@@ -23,6 +23,16 @@
.. change::
:tags: bug, orm
+ :tickets: 3801
+
+ An UPDATE emitted from the ORM flush process can now accommodate a
+ SQL expression element for a column within the primary key of an
+ object, if the target database supports RETURNING in order to provide
+ the new value, or if the PK value is set "to itself" for the purposes
+ of bumping some other trigger / onupdate on the column.
+
+ .. change::
+ :tags: bug, orm
:tickets: 3788
Fixed bug where the "simple many-to-one" condition that allows lazy
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 0b029f466..56b028375 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -506,6 +506,7 @@ def _collect_update_commands(
elif not (params or value_params):
continue
+ has_all_pks = True
if bulk:
pk_params = dict(
(propkey_to_col[propkey]._label, state_dict.get(propkey))
@@ -530,7 +531,8 @@ def _collect_update_commands(
else:
# else, use the old value to locate the row
pk_params[col._label] = history.deleted[0]
- params[col.key] = history.added[0]
+ if col in value_params:
+ has_all_pks = False
else:
pk_params[col._label] = history.unchanged[0]
if pk_params[col._label] is None:
@@ -542,7 +544,7 @@ def _collect_update_commands(
params.update(pk_params)
yield (
state, state_dict, params, mapper,
- connection, value_params, has_all_defaults)
+ connection, value_params, has_all_defaults, has_all_pks)
def _collect_post_update_commands(base_mapper, uowtransaction, table,
@@ -636,14 +638,15 @@ def _emit_update_statements(base_mapper, uowtransaction,
cached_stmt = base_mapper._memo(('update', table), update_stmt)
- for (connection, paramkeys, hasvalue, has_all_defaults), \
+ for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \
records in groupby(
update,
lambda rec: (
rec[4], # connection
set(rec[2]), # set of parameter keys
bool(rec[5]), # whether or not we have "value" parameters
- rec[6] # has_all_defaults
+ rec[6], # has_all_defaults
+ rec[7] # has all pks
)
):
rows = 0
@@ -659,7 +662,9 @@ def _emit_update_statements(base_mapper, uowtransaction,
connection.dialect.supports_sane_multi_rowcount
allow_multirow = has_all_defaults and not needs_version_id
- if bookkeeping and not has_all_defaults and \
+ if not has_all_pks:
+ statement = statement.return_defaults()
+ elif bookkeeping and not has_all_defaults and \
mapper.base_mapper.eager_defaults:
statement = statement.return_defaults()
elif mapper.version_id_col is not None:
@@ -667,7 +672,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
if hasvalue:
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, \
+ has_all_defaults, has_all_pks in records:
c = connection.execute(
statement.values(value_params),
params)
@@ -687,7 +693,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
if not allow_multirow:
check_rowcount = assert_singlerow
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, has_all_defaults, \
+ has_all_pks in records:
c = cached_connections[connection].\
execute(statement, params)
@@ -717,7 +724,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
rows += c.rowcount
for state, state_dict, params, mapper, \
- connection, value_params, has_all_defaults in records:
+ connection, value_params, \
+ has_all_defaults, has_all_pks in records:
if bookkeeping:
_postfetch(
mapper,
@@ -1013,7 +1021,9 @@ def _postfetch(mapper, uowtransaction, table,
row = result.context.returned_defaults
if row is not None:
for col in returning_cols:
- if col.primary_key:
+ # pk cols returned from insert are handled
+ # distinctly, don't step on the values here
+ if col.primary_key and result.context.isinsert:
continue
dict_[mapper._columntoproperty[col].key] = row[col]
if refresh_flush:
diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py
index 60387ddce..6780967c9 100644
--- a/test/orm/test_naturalpks.py
+++ b/test/orm/test_naturalpks.py
@@ -122,6 +122,43 @@ class NaturalPKTest(fixtures.MappedTest):
assert sess.query(User).get('jack') is None
assert sess.query(User).get('ed').fullname == 'jack'
+ @testing.requires.returning
+ def test_update_to_sql_expr(self):
+ users, User = self.tables.users, self.classes.User
+
+ mapper(User, users)
+
+ sess = create_session()
+ u1 = User(username='jack', fullname='jack')
+
+ sess.add(u1)
+ sess.flush()
+
+ u1.username = User.username + ' jones'
+
+ sess.flush()
+
+ eq_(u1.username, 'jack jones')
+
+ def test_update_to_self_sql_expr(self):
+ # SQL expression where the PK won't actually change,
+ # such as to bump a server side trigger
+ users, User = self.tables.users, self.classes.User
+
+ mapper(User, users)
+
+ sess = create_session()
+ u1 = User(username='jack', fullname='jack')
+
+ sess.add(u1)
+ sess.flush()
+
+ u1.username = User.username + ''
+
+ sess.flush()
+
+ eq_(u1.username, 'jack')
+
def test_flush_new_pk_after_expire(self):
User, users = self.classes.User, self.tables.users
diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py
index 07b090c60..40b373097 100644
--- a/test/orm/test_versioning.py
+++ b/test/orm/test_versioning.py
@@ -1020,6 +1020,38 @@ class ServerVersioningTest(fixtures.MappedTest):
)
self.assert_sql_execution(testing.db, sess.flush, *statements)
+ def test_sql_expr_bump(self):
+ sess = self._fixture()
+
+ f1 = self.classes.Foo(value='f1')
+ sess.add(f1)
+ sess.flush()
+
+ eq_(f1.version_id, 1)
+
+ f1.id = self.classes.Foo.id + 0
+
+ sess.flush()
+
+ eq_(f1.version_id, 2)
+
+ @testing.requires.returning
+ def test_sql_expr_w_mods_bump(self):
+ sess = self._fixture()
+
+ f1 = self.classes.Foo(id=2, value='f1')
+ sess.add(f1)
+ sess.flush()
+
+ eq_(f1.version_id, 1)
+
+ f1.id = self.classes.Foo.id + 3
+
+ sess.flush()
+
+ eq_(f1.id, 5)
+ eq_(f1.version_id, 2)
+
def test_multi_update(self):
sess = self._fixture()