summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/mssql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/mssql/base.py')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py53
1 files changed, 34 insertions, 19 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 4339551a3..a3855cc2c 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -685,7 +685,6 @@ import operator
import re
from . import information_schema as ischema
-from ... import engine
from ... import exc
from ... import schema as sa_schema
from ... import sql
@@ -693,6 +692,7 @@ from ... import types as sqltypes
from ... import util
from ...engine import default
from ...engine import reflection
+from ...engine import result as _result
from ...sql import compiler
from ...sql import elements
from ...sql import expression
@@ -1431,8 +1431,9 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
class MSExecutionContext(default.DefaultExecutionContext):
_enable_identity_insert = False
_select_lastrowid = False
- _result_proxy = None
_lastrowid = None
+ _rowcount = None
+ _result_strategy = None
def _opt_encode(self, statement):
if not self.dialect.supports_unicode_statements:
@@ -1500,6 +1501,10 @@ class MSExecutionContext(default.DefaultExecutionContext):
"""Disable IDENTITY_INSERT if enabled."""
conn = self.root_connection
+
+ if self.isinsert or self.isupdate or self.isdelete:
+ self._rowcount = self.cursor.rowcount
+
if self._select_lastrowid:
if self.dialect.use_scope_identity:
conn._cursor_execute(
@@ -1516,10 +1521,13 @@ class MSExecutionContext(default.DefaultExecutionContext):
row = self.cursor.fetchall()[0]
self._lastrowid = int(row[0])
- if (
+ elif (
self.isinsert or self.isupdate or self.isdelete
) and self.compiled.returning:
- self._result_proxy = engine.FullyBufferedResultProxy(self)
+ fbcr = _result.FullyBufferedCursorFetchStrategy
+ self._result_strategy = fbcr.create_from_buffer(
+ self.cursor, self.cursor.description, self.cursor.fetchall()
+ )
if self._enable_identity_insert:
conn._cursor_execute(
@@ -1537,6 +1545,13 @@ class MSExecutionContext(default.DefaultExecutionContext):
def get_lastrowid(self):
return self._lastrowid
+ @property
+ def rowcount(self):
+ if self._rowcount is not None:
+ return self._rowcount
+ else:
+ return self.cursor.rowcount
+
def handle_dbapi_exception(self, e):
if self._enable_identity_insert:
try:
@@ -1551,11 +1566,13 @@ class MSExecutionContext(default.DefaultExecutionContext):
except Exception:
pass
- def get_result_proxy(self):
- if self._result_proxy:
- return self._result_proxy
+ def get_result_cursor_strategy(self, result):
+ if self._result_strategy:
+ return self._result_strategy
else:
- return engine.ResultProxy(self)
+ return super(MSExecutionContext, self).get_result_cursor_strategy(
+ result
+ )
class MSSQLCompiler(compiler.SQLCompiler):
@@ -2570,7 +2587,7 @@ class MSDialect(default.DefaultDialect):
if self.server_version_info < MS_2005_VERSION:
return []
- rp = connection.execute(
+ rp = connection.execution_options(future_result=True).execute(
sql.text(
"select ind.index_id, ind.is_unique, ind.name "
"from sys.indexes as ind join sys.tables as tab on "
@@ -2587,13 +2604,13 @@ class MSDialect(default.DefaultDialect):
.columns(name=sqltypes.Unicode())
)
indexes = {}
- for row in rp:
+ for row in rp.mappings():
indexes[row["index_id"]] = {
"name": row["name"],
"unique": row["is_unique"] == 1,
"column_names": [],
}
- rp = connection.execute(
+ rp = connection.execution_options(future_result=True).execute(
sql.text(
"select ind_col.index_id, ind_col.object_id, col.name "
"from sys.columns as col "
@@ -2611,7 +2628,7 @@ class MSDialect(default.DefaultDialect):
)
.columns(name=sqltypes.Unicode())
)
- for row in rp:
+ for row in rp.mappings():
if row["index_id"] in indexes:
indexes[row["index_id"]]["column_names"].append(row["name"])
@@ -2657,12 +2674,10 @@ class MSDialect(default.DefaultDialect):
[columns], whereclause, order_by=[columns.c.ordinal_position]
)
- c = connection.execute(s)
+ c = connection.execution_options(future_result=True).execute(s)
cols = []
- while True:
- row = c.fetchone()
- if row is None:
- break
+
+ for row in c.mappings():
(
name,
type_,
@@ -2785,9 +2800,9 @@ class MSDialect(default.DefaultDialect):
C.c.table_schema == owner,
),
)
- c = connection.execute(s)
+ c = connection.execution_options(future_result=True).execute(s)
constraint_name = None
- for row in c:
+ for row in c.mappings():
if "PRIMARY" in row[TC.c.constraint_type.name]:
pkeys.append(row[0])
if constraint_name is None: