summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_12/4062.rst13
-rw-r--r--lib/sqlalchemy/connectors/pyodbc.py1
-rw-r--r--lib/sqlalchemy/engine/default.py4
-rw-r--r--lib/sqlalchemy/orm/persistence.py24
-rw-r--r--lib/sqlalchemy/testing/requirements.py23
-rw-r--r--test/requirements.py14
-rw-r--r--test/sql/test_rowcount.py10
7 files changed, 66 insertions, 23 deletions
diff --git a/doc/build/changelog/unreleased_12/4062.rst b/doc/build/changelog/unreleased_12/4062.rst
new file mode 100644
index 000000000..3a89a1ad6
--- /dev/null
+++ b/doc/build/changelog/unreleased_12/4062.rst
@@ -0,0 +1,13 @@
+.. change::
+ :tags: bug, mssql, orm
+ :tickets: 4062
+
+ Added a new class of "rowcount support" for dialects that is specific to
+ when "RETURNING", which on SQL Server looks like "OUTPUT inserted", is in
+ use, as the PyODBC backend isn't able to give us rowcount on an UPDATE or
+ DELETE statement when OUTPUT is in effect. This primarily affects the ORM
+ when a flush is updating a row that contains server-calcluated values,
+ raising an error if the backend does not return the expected row count.
+ PyODBC now states that it supports rowcount except if OUTPUT.inserted is
+ present, which is taken into account by the ORM during a flush as to
+ whether it will look for a rowcount.
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
index 65fe37212..66acf0072 100644
--- a/lib/sqlalchemy/connectors/pyodbc.py
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -16,6 +16,7 @@ import re
class PyODBCConnector(Connector):
driver = 'pyodbc'
+ supports_sane_rowcount_returning = False
supports_sane_multi_rowcount = False
if util.py2k:
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index d1b54ab01..8b72c0001 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -249,6 +249,10 @@ class DefaultDialect(interfaces.Dialect):
def dialect_description(self):
return self.name + "+" + self.driver
+ @property
+ def supports_sane_rowcount_returning(self):
+ return self.supports_sane_rowcount
+
@classmethod
def get_pool_class(cls, url):
return getattr(cls, 'poolclass', pool.QueuePool)
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index faacd018e..24c9743d4 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -693,22 +693,28 @@ def _emit_update_statements(base_mapper, uowtransaction,
records = list(records)
statement = cached_stmt
-
- # TODO: would be super-nice to not have to determine this boolean
- # inside the loop here, in the 99.9999% of the time there's only
- # one connection in use
- assert_singlerow = connection.dialect.supports_sane_rowcount
- assert_multirow = assert_singlerow and \
- connection.dialect.supports_sane_multi_rowcount
- allow_multirow = has_all_defaults and not needs_version_id
+ return_defaults = False
if not has_all_pks:
statement = statement.return_defaults()
+ return_defaults = True
elif bookkeeping and not has_all_defaults and \
mapper.base_mapper.eager_defaults:
statement = statement.return_defaults()
+ return_defaults = True
elif mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
+ return_defaults = True
+
+ assert_singlerow = (
+ connection.dialect.supports_sane_rowcount
+ if not return_defaults
+ else connection.dialect.supports_sane_rowcount_returning
+ )
+
+ assert_multirow = assert_singlerow and \
+ connection.dialect.supports_sane_multi_rowcount
+ allow_multirow = has_all_defaults and not needs_version_id
if hasvalue:
for state, state_dict, params, mapper, \
@@ -728,7 +734,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
c.context.compiled_parameters[0],
value_params)
rows += c.rowcount
- check_rowcount = True
+ check_rowcount = assert_singlerow
else:
if not allow_multirow:
check_rowcount = assert_singlerow
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index 08a7b1ced..327362bf6 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -193,6 +193,29 @@ class SuiteRequirements(Requirements):
return exclusions.open()
+
+ @property
+ def sane_rowcount(self):
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.supports_sane_rowcount,
+ "driver doesn't support 'sane' rowcount"
+ )
+
+ @property
+ def sane_multi_rowcount(self):
+ return exclusions.fails_if(
+ lambda config: not config.db.dialect.supports_sane_multi_rowcount,
+ "driver %(driver)s %(doesnt_support)s 'sane' multi row count"
+ )
+
+ @property
+ def sane_rowcount_w_returning(self):
+ return exclusions.fails_if(
+ lambda config:
+ not config.db.dialect.supports_sane_rowcount_returning,
+ "driver doesn't support 'sane' rowcount when returning is on"
+ )
+
@property
def empty_inserts(self):
"""target platform supports INSERT with no values, i.e.
diff --git a/test/requirements.py b/test/requirements.py
index 4f01eac9b..0362e28d1 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -559,13 +559,6 @@ class DefaultRequirements(SuiteRequirements):
])
@property
- def sane_rowcount(self):
- return skip_if(
- lambda config: not config.db.dialect.supports_sane_rowcount,
- "driver doesn't support 'sane' rowcount"
- )
-
- @property
def emulated_lastrowid(self):
""""target dialect retrieves cursor.lastrowid or an equivalent
after an insert() construct executes.
@@ -594,13 +587,6 @@ class DefaultRequirements(SuiteRequirements):
'sqlite+pysqlcipher')
@property
- def sane_multi_rowcount(self):
- return fails_if(
- lambda config: not config.db.dialect.supports_sane_multi_rowcount,
- "driver %(driver)s %(doesnt_support)s 'sane' multi row count"
- )
-
- @property
def nullsordering(self):
"""Target backends that support nulls ordering."""
return fails_on_everything_except('postgresql', 'oracle', 'firebird')
diff --git a/test/sql/test_rowcount.py b/test/sql/test_rowcount.py
index 16087b94c..3399ba7ec 100644
--- a/test/sql/test_rowcount.py
+++ b/test/sql/test_rowcount.py
@@ -65,6 +65,15 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults):
r = employees_table.update(department == 'C').execute(department='C')
assert r.rowcount == 3
+ @testing.requires.sane_rowcount_w_returning
+ def test_update_rowcount_return_defaults(self):
+ department = employees_table.c.department
+ stmt = employees_table.update(department == 'C').values(
+ name=employees_table.c.department + 'Z').return_defaults()
+
+ r = stmt.execute()
+ assert r.rowcount == 3
+
def test_raw_sql_rowcount(self):
# test issue #3622, make sure eager rowcount is called for text
with testing.db.connect() as conn:
@@ -117,3 +126,4 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults):
eq_(
r.rowcount, 2
)
+