diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2016-09-20 11:56:05 -0400 |
---|---|---|
committer | Gerrit Code Review <gerrit2@ln3.zzzcomputing.com> | 2016-09-20 11:56:05 -0400 |
commit | 5af3c903368e9a437a6ceafce5dc993433420cc7 (patch) | |
tree | 65a957c6154b6d43a001241330d756bf70d8efc3 | |
parent | b9a7a74d5e729408fcac86fe2919aa423c59d863 (diff) | |
parent | f8ecdf47f0975b8b4e357fde2008d9aae8c50239 (diff) | |
download | sqlalchemy-5af3c903368e9a437a6ceafce5dc993433420cc7.tar.gz |
Merge "Allow SQL expressions to be set on PK columns"
-rw-r--r-- | doc/build/changelog/changelog_11.rst | 10 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 28 | ||||
-rw-r--r-- | test/orm/test_naturalpks.py | 37 | ||||
-rw-r--r-- | test/orm/test_versioning.py | 32 |
4 files changed, 98 insertions, 9 deletions
diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst index 0d6c6b159..efd2d33f9 100644 --- a/doc/build/changelog/changelog_11.rst +++ b/doc/build/changelog/changelog_11.rst @@ -44,6 +44,16 @@ .. change:: :tags: bug, orm + :tickets: 3801 + + An UPDATE emitted from the ORM flush process can now accommodate a + SQL expression element for a column within the primary key of an + object, if the target database supports RETURNING in order to provide + the new value, or if the PK value is set "to itself" for the purposes + of bumping some other trigger / onupdate on the column. + + .. change:: + :tags: bug, orm :tickets: 3788 Fixed bug where the "simple many-to-one" condition that allows lazy diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 0b029f466..56b028375 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -506,6 +506,7 @@ def _collect_update_commands( elif not (params or value_params): continue + has_all_pks = True if bulk: pk_params = dict( (propkey_to_col[propkey]._label, state_dict.get(propkey)) @@ -530,7 +531,8 @@ def _collect_update_commands( else: # else, use the old value to locate the row pk_params[col._label] = history.deleted[0] - params[col.key] = history.added[0] + if col in value_params: + has_all_pks = False else: pk_params[col._label] = history.unchanged[0] if pk_params[col._label] is None: @@ -542,7 +544,7 @@ def _collect_update_commands( params.update(pk_params) yield ( state, state_dict, params, mapper, - connection, value_params, has_all_defaults) + connection, value_params, has_all_defaults, has_all_pks) def _collect_post_update_commands(base_mapper, uowtransaction, table, @@ -636,14 +638,15 @@ def _emit_update_statements(base_mapper, uowtransaction, cached_stmt = base_mapper._memo(('update', table), update_stmt) - for (connection, paramkeys, hasvalue, has_all_defaults), \ + for (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), \ records in groupby( update, lambda rec: ( rec[4], # connection set(rec[2]), # set of parameter keys bool(rec[5]), # whether or not we have "value" parameters - rec[6] # has_all_defaults + rec[6], # has_all_defaults + rec[7] # has all pks ) ): rows = 0 @@ -659,7 +662,9 @@ def _emit_update_statements(base_mapper, uowtransaction, connection.dialect.supports_sane_multi_rowcount allow_multirow = has_all_defaults and not needs_version_id - if bookkeeping and not has_all_defaults and \ + if not has_all_pks: + statement = statement.return_defaults() + elif bookkeeping and not has_all_defaults and \ mapper.base_mapper.eager_defaults: statement = statement.return_defaults() elif mapper.version_id_col is not None: @@ -667,7 +672,8 @@ def _emit_update_statements(base_mapper, uowtransaction, if hasvalue: for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, \ + has_all_defaults, has_all_pks in records: c = connection.execute( statement.values(value_params), params) @@ -687,7 +693,8 @@ def _emit_update_statements(base_mapper, uowtransaction, if not allow_multirow: check_rowcount = assert_singlerow for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, has_all_defaults, \ + has_all_pks in records: c = cached_connections[connection].\ execute(statement, params) @@ -717,7 +724,8 @@ def _emit_update_statements(base_mapper, uowtransaction, rows += c.rowcount for state, state_dict, params, mapper, \ - connection, value_params, has_all_defaults in records: + connection, value_params, \ + has_all_defaults, has_all_pks in records: if bookkeeping: _postfetch( mapper, @@ -1013,7 +1021,9 @@ def _postfetch(mapper, uowtransaction, table, row = result.context.returned_defaults if row is not None: for col in returning_cols: - if col.primary_key: + # pk cols returned from insert are handled + # distinctly, don't step on the values here + if col.primary_key and result.context.isinsert: continue dict_[mapper._columntoproperty[col].key] = row[col] if refresh_flush: diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py index 60387ddce..6780967c9 100644 --- a/test/orm/test_naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -122,6 +122,43 @@ class NaturalPKTest(fixtures.MappedTest): assert sess.query(User).get('jack') is None assert sess.query(User).get('ed').fullname == 'jack' + @testing.requires.returning + def test_update_to_sql_expr(self): + users, User = self.tables.users, self.classes.User + + mapper(User, users) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + + sess.add(u1) + sess.flush() + + u1.username = User.username + ' jones' + + sess.flush() + + eq_(u1.username, 'jack jones') + + def test_update_to_self_sql_expr(self): + # SQL expression where the PK won't actually change, + # such as to bump a server side trigger + users, User = self.tables.users, self.classes.User + + mapper(User, users) + + sess = create_session() + u1 = User(username='jack', fullname='jack') + + sess.add(u1) + sess.flush() + + u1.username = User.username + '' + + sess.flush() + + eq_(u1.username, 'jack') + def test_flush_new_pk_after_expire(self): User, users = self.classes.User, self.tables.users diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index 07b090c60..40b373097 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -1020,6 +1020,38 @@ class ServerVersioningTest(fixtures.MappedTest): ) self.assert_sql_execution(testing.db, sess.flush, *statements) + def test_sql_expr_bump(self): + sess = self._fixture() + + f1 = self.classes.Foo(value='f1') + sess.add(f1) + sess.flush() + + eq_(f1.version_id, 1) + + f1.id = self.classes.Foo.id + 0 + + sess.flush() + + eq_(f1.version_id, 2) + + @testing.requires.returning + def test_sql_expr_w_mods_bump(self): + sess = self._fixture() + + f1 = self.classes.Foo(id=2, value='f1') + sess.add(f1) + sess.flush() + + eq_(f1.version_id, 1) + + f1.id = self.classes.Foo.id + 3 + + sess.flush() + + eq_(f1.id, 5) + eq_(f1.version_id, 2) + def test_multi_update(self): sess = self._fixture() |