summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-12-14 17:24:47 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2015-12-14 17:30:21 -0500
commit0e4c4d7efc08d04c3c0ae960428b08ada37e4a91 (patch)
tree4421c6681b9bc6025c5baccffbe5d61b901c48da
parent7d96ad4d535dc02a8ab1384df1db94dea2a045b5 (diff)
downloadsqlalchemy-0e4c4d7efc08d04c3c0ae960428b08ada37e4a91.tar.gz
- Fixed bug in :meth:`.Update.return_defaults` which would cause all
insert-default holding columns not otherwise included in the SET clause (such as primary key cols) to get rendered into the RETURNING even though this is an UPDATE. - Major fixes to the :paramref:`.Mapper.eager_defaults` flag, this flag would not be honored correctly in the case that multiple UPDATE statements were to be emitted, either as part of a flush or a bulk update operation. Additionally, RETURNING would be emitted unnecessarily within update statements. fixes #3609
-rw-r--r--doc/build/changelog/changelog_10.rst21
-rw-r--r--lib/sqlalchemy/orm/mapper.py14
-rw-r--r--lib/sqlalchemy/orm/persistence.py36
-rw-r--r--lib/sqlalchemy/sql/crud.py1
-rw-r--r--lib/sqlalchemy/testing/assertsql.py17
-rw-r--r--test/orm/test_unitofworkv2.py447
-rw-r--r--test/orm/test_versioning.py19
-rw-r--r--test/sql/test_returning.py25
8 files changed, 559 insertions, 21 deletions
diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst
index 950046cd0..974aa5f1a 100644
--- a/doc/build/changelog/changelog_10.rst
+++ b/doc/build/changelog/changelog_10.rst
@@ -19,6 +19,27 @@
:version: 1.0.11
.. change::
+ :tags: bug, sql
+ :tickets: 3609
+ :versions: 1.1.0b1
+
+ Fixed bug in :meth:`.Update.return_defaults` which would cause all
+ insert-default holding columns not otherwise included in the SET
+ clause (such as primary key cols) to get rendered into the RETURNING
+ even though this is an UPDATE.
+
+ .. change::
+ :tags: bug, orm
+ :tickets: 3609
+ :versions: 1.1.0b1
+
+ Major fixes to the :paramref:`.Mapper.eager_defaults` flag, this
+ flag would not be honored correctly in the case that multiple
+ UPDATE statements were to be emitted, either as part of a flush
+ or a bulk update operation. Additionally, RETURNING
+ would be emitted unnecessarily within update statements.
+
+ .. change::
:tags: bug, orm
:tickets: 3606
:versions: 1.1.0b1
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 5ade4b966..95aa14a26 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1970,12 +1970,24 @@ class Mapper(InspectionAttr):
(
table,
frozenset([
- col for col in columns
+ col.key for col in columns
if col.server_default is not None])
)
for table, columns in self._cols_by_table.items()
)
+ @_memoized_configured_property
+ def _server_onupdate_default_cols(self):
+ return dict(
+ (
+ table,
+ frozenset([
+ col.key for col in columns
+ if col.server_onupdate is not None])
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
@property
def selectable(self):
"""The :func:`.select` construct this :class:`.Mapper` selects from
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 768c1146a..88c96e94c 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -448,6 +448,7 @@ def _collect_update_commands(
set(propkey_to_col).intersection(state_dict).difference(
mapper._pk_keys_by_table[table])
)
+ has_all_defaults = True
else:
params = {}
for propkey in set(propkey_to_col).intersection(
@@ -463,6 +464,12 @@ def _collect_update_commands(
value, state.committed_state[propkey]) is not True:
params[col.key] = value
+ if mapper.base_mapper.eager_defaults:
+ has_all_defaults = mapper._server_onupdate_default_cols[table].\
+ issubset(params)
+ else:
+ has_all_defaults = True
+
if update_version_id is not None and \
mapper.version_id_col in mapper._cols_by_table[table]:
@@ -529,7 +536,7 @@ def _collect_update_commands(
params.update(pk_params)
yield (
state, state_dict, params, mapper,
- connection, value_params)
+ connection, value_params, has_all_defaults)
def _collect_post_update_commands(base_mapper, uowtransaction, table,
@@ -619,23 +626,20 @@ def _emit_update_statements(base_mapper, uowtransaction,
type_=mapper.version_id_col.type))
stmt = table.update(clause)
- if mapper.base_mapper.eager_defaults:
- stmt = stmt.return_defaults()
- elif mapper.version_id_col is not None:
- stmt = stmt.return_defaults(mapper.version_id_col)
-
return stmt
statement = base_mapper._memo(('update', table), update_stmt)
- for (connection, paramkeys, hasvalue), \
+ for (connection, paramkeys, hasvalue, has_all_defaults), \
records in groupby(
update,
lambda rec: (
rec[4], # connection
set(rec[2]), # set of parameter keys
- bool(rec[5]))): # whether or not we have "value" parameters
-
+ bool(rec[5]), # whether or not we have "value" parameters
+ rec[6] # has_all_defaults
+ )
+ ):
rows = 0
records = list(records)
@@ -645,11 +649,16 @@ def _emit_update_statements(base_mapper, uowtransaction,
assert_singlerow = connection.dialect.supports_sane_rowcount
assert_multirow = assert_singlerow and \
connection.dialect.supports_sane_multi_rowcount
- allow_multirow = not needs_version_id
+ allow_multirow = has_all_defaults and not needs_version_id
+
+ if bookkeeping and mapper.base_mapper.eager_defaults:
+ statement = statement.return_defaults()
+ elif mapper.version_id_col is not None:
+ statement = statement.return_defaults(mapper.version_id_col)
if hasvalue:
for state, state_dict, params, mapper, \
- connection, value_params in records:
+ connection, value_params, has_all_defaults in records:
c = connection.execute(
statement.values(value_params),
params)
@@ -669,7 +678,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
if not allow_multirow:
check_rowcount = assert_singlerow
for state, state_dict, params, mapper, \
- connection, value_params in records:
+ connection, value_params, has_all_defaults in records:
c = cached_connections[connection].\
execute(statement, params)
@@ -699,7 +708,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
rows += c.rowcount
for state, state_dict, params, mapper, \
- connection, value_params in records:
+ connection, value_params, has_all_defaults in records:
if bookkeeping:
_postfetch(
mapper,
@@ -741,6 +750,7 @@ def _emit_insert_statements(base_mapper, uowtransaction,
bool(rec[5]), # whether we have "value" parameters
rec[6],
rec[7])):
+
if not bookkeeping or \
(
has_all_defaults
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 18b96018d..c5495ccde 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -493,6 +493,7 @@ def _append_param_update(
else:
compiler.postfetch.append(c)
elif implicit_return_defaults and \
+ stmt._return_defaults is not True and \
c in implicit_return_defaults:
compiler.returning.append(c)
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 243493607..39d078985 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -13,6 +13,7 @@ import contextlib
from .. import event
from sqlalchemy.schema import _DDLCompiles
from sqlalchemy.engine.util import _distill_params
+from sqlalchemy.engine import url
class AssertRule(object):
@@ -58,16 +59,25 @@ class CursorSQL(SQLMatchRule):
class CompiledSQL(SQLMatchRule):
- def __init__(self, statement, params=None):
+ def __init__(self, statement, params=None, dialect='default'):
self.statement = statement
self.params = params
+ self.dialect = dialect
def _compare_sql(self, execute_observed, received_statement):
stmt = re.sub(r'[\n\t]', '', self.statement)
return received_statement == stmt
def _compile_dialect(self, execute_observed):
- return DefaultDialect()
+ if self.dialect == 'default':
+ return DefaultDialect()
+ else:
+ # ugh
+ if self.dialect == 'postgresql':
+ params = {'implicit_returning': True}
+ else:
+ params = {}
+ return url.URL(self.dialect).get_dialect()(**params)
def _received_statement(self, execute_observed):
"""reconstruct the statement and params in terms
@@ -159,7 +169,7 @@ class CompiledSQL(SQLMatchRule):
'Testing for compiled statement %r partial params %r, '
'received %%(received_statement)r with params '
'%%(received_parameters)r' % (
- self.statement, expected_params
+ self.statement.replace('%', '%%'), expected_params
)
)
@@ -170,6 +180,7 @@ class RegexSQL(CompiledSQL):
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
+ self.dialect = 'default'
def _failure_message(self, expected_params):
return (
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index 09240dfdb..c8ce13c91 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -5,7 +5,8 @@ from sqlalchemy.testing.schema import Table, Column
from test.orm import _fixtures
from sqlalchemy import exc, util
from sqlalchemy.testing import fixtures, config
-from sqlalchemy import Integer, String, ForeignKey, func, literal
+from sqlalchemy import Integer, String, ForeignKey, func, \
+ literal, FetchedValue, text
from sqlalchemy.orm import mapper, relationship, backref, \
create_session, unitofwork, attributes,\
Session, exc as orm_exc
@@ -1848,6 +1849,450 @@ class NoAttrEventInFlushTest(fixtures.MappedTest):
eq_(t1.returning_val, 5)
+class EagerDefaultsTest(fixtures.MappedTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ 'test', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('foo', Integer, server_default="3")
+ )
+
+ Table(
+ 'test2', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('foo', Integer),
+ Column('bar', Integer, server_onupdate=FetchedValue())
+ )
+
+ @classmethod
+ def setup_classes(cls):
+ class Thing(cls.Basic):
+ pass
+
+ class Thing2(cls.Basic):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ Thing = cls.classes.Thing
+
+ mapper(Thing, cls.tables.test, eager_defaults=True)
+
+ Thing2 = cls.classes.Thing2
+
+ mapper(Thing2, cls.tables.test2, eager_defaults=True)
+
+ def test_insert_defaults_present(self):
+ Thing = self.classes.Thing
+ s = Session()
+
+ t1, t2 = (
+ Thing(id=1, foo=5),
+ Thing(id=2, foo=10)
+ )
+
+ s.add_all([t1, t2])
+
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "INSERT INTO test (id, foo) VALUES (:id, :foo)",
+ [{'foo': 5, 'id': 1}, {'foo': 10, 'id': 2}]
+ ),
+ )
+
+ def go():
+ eq_(t1.foo, 5)
+ eq_(t2.foo, 10)
+
+ self.assert_sql_count(testing.db, go, 0)
+
+ def test_insert_defaults_present_as_expr(self):
+ Thing = self.classes.Thing
+ s = Session()
+
+ t1, t2 = (
+ Thing(id=1, foo=text("2 + 5")),
+ Thing(id=2, foo=text("5 + 5"))
+ )
+
+ s.add_all([t1, t2])
+
+ if testing.db.dialect.implicit_returning:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "INSERT INTO test (id, foo) VALUES (%(id)s, 2 + 5) "
+ "RETURNING test.foo",
+ [{'id': 1}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "INSERT INTO test (id, foo) VALUES (%(id)s, 5 + 5) "
+ "RETURNING test.foo",
+ [{'id': 2}],
+ dialect='postgresql'
+ )
+ )
+
+ else:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "INSERT INTO test (id, foo) VALUES (:id, 2 + 5)",
+ [{'id': 1}]
+ ),
+ CompiledSQL(
+ "INSERT INTO test (id, foo) VALUES (:id, 5 + 5)",
+ [{'id': 2}]
+ ),
+ CompiledSQL(
+ "SELECT test.foo AS test_foo FROM test "
+ "WHERE test.id = :param_1",
+ [{'param_1': 1}]
+ ),
+ CompiledSQL(
+ "SELECT test.foo AS test_foo FROM test "
+ "WHERE test.id = :param_1",
+ [{'param_1': 2}]
+ ),
+ )
+
+ def go():
+ eq_(t1.foo, 7)
+ eq_(t2.foo, 10)
+
+ self.assert_sql_count(testing.db, go, 0)
+
+ def test_insert_defaults_nonpresent(self):
+ Thing = self.classes.Thing
+ s = Session()
+
+ t1, t2 = (
+ Thing(id=1),
+ Thing(id=2)
+ )
+
+ s.add_all([t1, t2])
+
+ if testing.db.dialect.implicit_returning:
+ self.assert_sql_execution(
+ testing.db,
+ s.commit,
+ CompiledSQL(
+ "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo",
+ [{'id': 1}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo",
+ [{'id': 2}],
+ dialect='postgresql'
+ ),
+ )
+ else:
+ self.assert_sql_execution(
+ testing.db,
+ s.commit,
+ CompiledSQL(
+ "INSERT INTO test (id) VALUES (:id)",
+ [{'id': 1}, {'id': 2}]
+ ),
+ CompiledSQL(
+ "SELECT test.foo AS test_foo FROM test "
+ "WHERE test.id = :param_1",
+ [{'param_1': 1}]
+ ),
+ CompiledSQL(
+ "SELECT test.foo AS test_foo FROM test "
+ "WHERE test.id = :param_1",
+ [{'param_1': 2}]
+ )
+ )
+
+ def test_update_defaults_nonpresent(self):
+ Thing2 = self.classes.Thing2
+ s = Session()
+
+ t1, t2, t3, t4 = (
+ Thing2(id=1, foo=1, bar=2),
+ Thing2(id=2, foo=2, bar=3),
+ Thing2(id=3, foo=3, bar=4),
+ Thing2(id=4, foo=4, bar=5)
+ )
+
+ s.add_all([t1, t2, t3, t4])
+ s.flush()
+
+ t1.foo = 5
+ t2.foo = 6
+ t2.bar = 10
+ t3.foo = 7
+ t4.foo = 8
+ t4.bar = 12
+
+ if testing.db.dialect.implicit_returning:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s "
+ "WHERE test2.id = %(test2_id)s "
+ "RETURNING test2.bar",
+ [{'foo': 5, 'test2_id': 1}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
+ "WHERE test2.id = %(test2_id)s",
+ [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s "
+ "WHERE test2.id = %(test2_id)s "
+ "RETURNING test2.bar",
+ [{'foo': 7, 'test2_id': 3}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
+ "WHERE test2.id = %(test2_id)s",
+ [{'foo': 8, 'bar': 12, 'test2_id': 4}],
+ dialect='postgresql'
+ ),
+ )
+ else:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+ [{'foo': 5, 'test2_id': 1}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=:bar "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+ [{'foo': 7, 'test2_id': 3}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=:bar "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 8, 'bar': 12, 'test2_id': 4}],
+ ),
+ CompiledSQL(
+ "SELECT test2.bar AS test2_bar FROM test2 "
+ "WHERE test2.id = :param_1",
+ [{'param_1': 1}]
+ ),
+ CompiledSQL(
+ "SELECT test2.bar AS test2_bar FROM test2 "
+ "WHERE test2.id = :param_1",
+ [{'param_1': 3}]
+ )
+ )
+
+ def go():
+ eq_(t1.bar, 2)
+ eq_(t2.bar, 10)
+ eq_(t3.bar, 4)
+ eq_(t4.bar, 12)
+
+ self.assert_sql_count(testing.db, go, 0)
+
+ def test_update_defaults_present_as_expr(self):
+ Thing2 = self.classes.Thing2
+ s = Session()
+
+ t1, t2, t3, t4 = (
+ Thing2(id=1, foo=1, bar=2),
+ Thing2(id=2, foo=2, bar=3),
+ Thing2(id=3, foo=3, bar=4),
+ Thing2(id=4, foo=4, bar=5)
+ )
+
+ s.add_all([t1, t2, t3, t4])
+ s.flush()
+
+ t1.foo = 5
+ t1.bar = text("1 + 1")
+ t2.foo = 6
+ t2.bar = 10
+ t3.foo = 7
+ t4.foo = 8
+ t4.bar = text("5 + 7")
+
+ if testing.db.dialect.implicit_returning:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=1 + 1 "
+ "WHERE test2.id = %(test2_id)s "
+ "RETURNING test2.bar",
+ [{'foo': 5, 'test2_id': 1}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
+ "WHERE test2.id = %(test2_id)s",
+ [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s "
+ "WHERE test2.id = %(test2_id)s "
+ "RETURNING test2.bar",
+ [{'foo': 7, 'test2_id': 3}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=5 + 7 "
+ "WHERE test2.id = %(test2_id)s RETURNING test2.bar",
+ [{'foo': 8, 'test2_id': 4}],
+ dialect='postgresql'
+ ),
+ )
+ else:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=1 + 1 "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 5, 'test2_id': 1}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=:bar "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+ [{'foo': 7, 'test2_id': 3}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=5 + 7 "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 8, 'test2_id': 4}],
+ ),
+ CompiledSQL(
+ "SELECT test2.bar AS test2_bar FROM test2 "
+ "WHERE test2.id = :param_1",
+ [{'param_1': 1}]
+ ),
+ CompiledSQL(
+ "SELECT test2.bar AS test2_bar FROM test2 "
+ "WHERE test2.id = :param_1",
+ [{'param_1': 3}]
+ ),
+ CompiledSQL(
+ "SELECT test2.bar AS test2_bar FROM test2 "
+ "WHERE test2.id = :param_1",
+ [{'param_1': 4}]
+ )
+ )
+
+ def go():
+ eq_(t1.bar, 2)
+ eq_(t2.bar, 10)
+ eq_(t3.bar, 4)
+ eq_(t4.bar, 12)
+
+ self.assert_sql_count(testing.db, go, 0)
+
+ def test_insert_defaults_bulk_insert(self):
+ Thing = self.classes.Thing
+ s = Session()
+
+ mappings = [
+ {"id": 1},
+ {"id": 2}
+ ]
+
+ self.assert_sql_execution(
+ testing.db,
+ lambda: s.bulk_insert_mappings(Thing, mappings),
+ CompiledSQL(
+ "INSERT INTO test (id) VALUES (:id)",
+ [{'id': 1}, {'id': 2}]
+ )
+ )
+
+ def test_update_defaults_bulk_update(self):
+ Thing2 = self.classes.Thing2
+ s = Session()
+
+ t1, t2, t3, t4 = (
+ Thing2(id=1, foo=1, bar=2),
+ Thing2(id=2, foo=2, bar=3),
+ Thing2(id=3, foo=3, bar=4),
+ Thing2(id=4, foo=4, bar=5)
+ )
+
+ s.add_all([t1, t2, t3, t4])
+ s.flush()
+
+ mappings = [
+ {"id": 1, "foo": 5},
+ {"id": 2, "foo": 6, "bar": 10},
+ {"id": 3, "foo": 7},
+ {"id": 4, "foo": 8}
+ ]
+
+ self.assert_sql_execution(
+ testing.db,
+ lambda: s.bulk_update_mappings(Thing2, mappings),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+ [{'foo': 5, 'test2_id': 1}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=:bar "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 6, 'bar': 10, 'test2_id': 2}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+ [{'foo': 7, 'test2_id': 3}, {'foo': 8, 'test2_id': 4}]
+ )
+ )
+
+ def test_update_defaults_present(self):
+ Thing2 = self.classes.Thing2
+ s = Session()
+
+ t1, t2 = (
+ Thing2(id=1, foo=1, bar=2),
+ Thing2(id=2, foo=2, bar=3)
+ )
+
+ s.add_all([t1, t2])
+ s.flush()
+
+ t1.bar = 5
+ t2.bar = 10
+
+ self.assert_sql_execution(
+ testing.db,
+ s.commit,
+ CompiledSQL(
+ "UPDATE test2 SET bar=%(bar)s WHERE test2.id = %(test2_id)s",
+ [{'bar': 5, 'test2_id': 1}, {'bar': 10, 'test2_id': 2}],
+ dialect='postgresql'
+ )
+ )
+
class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults):
"""test support for custom datatypes that return a non-__bool__ value
when compared via __eq__(), eg. ticket 3469"""
diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py
index f42069230..124053d47 100644
--- a/test/orm/test_versioning.py
+++ b/test/orm/test_versioning.py
@@ -894,19 +894,26 @@ class ServerVersioningTest(fixtures.MappedTest):
class Bar(cls.Basic):
pass
- def _fixture(self, expire_on_commit=True):
+ def _fixture(self, expire_on_commit=True, eager_defaults=False):
Foo, version_table = self.classes.Foo, self.tables.version_table
mapper(
Foo, version_table, version_id_col=version_table.c.version_id,
version_id_generator=False,
+ eager_defaults=eager_defaults
)
s1 = Session(expire_on_commit=expire_on_commit)
return s1
def test_insert_col(self):
- sess = self._fixture()
+ self._test_insert_col()
+
+ def test_insert_col_eager_defaults(self):
+ self._test_insert_col(eager_defaults=True)
+
+ def _test_insert_col(self, **kw):
+ sess = self._fixture(**kw)
f1 = self.classes.Foo(value='f1')
sess.add(f1)
@@ -935,7 +942,13 @@ class ServerVersioningTest(fixtures.MappedTest):
self.assert_sql_execution(testing.db, sess.flush, *statements)
def test_update_col(self):
- sess = self._fixture()
+ self._test_update_col()
+
+ def test_update_col_eager_defaults(self):
+ self._test_update_col(eager_defaults=True)
+
+ def _test_update_col(self, **kw):
+ sess = self._fixture(**kw)
f1 = self.classes.Foo(value='f1')
sess.add(f1)
diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py
index cd9f632b9..77a0c6007 100644
--- a/test/sql/test_returning.py
+++ b/test/sql/test_returning.py
@@ -387,6 +387,31 @@ class ReturnDefaultsTest(fixtures.TablesTest):
{"data": None, 'upddef': 1}
)
+ def test_insert_all(self):
+ t1 = self.tables.t1
+ result = testing.db.execute(
+ t1.insert().values(upddef=1).return_defaults()
+ )
+ eq_(
+ dict(result.returned_defaults),
+ {"id": 1, "data": None, "insdef": 0}
+ )
+
+ def test_update_all(self):
+ t1 = self.tables.t1
+ testing.db.execute(
+ t1.insert().values(upddef=1)
+ )
+ result = testing.db.execute(
+ t1.update().
+ values(insdef=2).return_defaults()
+ )
+ eq_(
+ dict(result.returned_defaults),
+ {'upddef': 1}
+ )
+
+
class ImplicitReturningFlag(fixtures.TestBase):
__backend__ = True