summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/mssql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/mssql/base.py')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py1238
1 files changed, 771 insertions, 467 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 9269225d3..161297015 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -655,9 +655,22 @@ from ...sql import compiler, expression, util as sql_util, quoted_name
from ... import engine
from ...engine import reflection, default
from ... import types as sqltypes
-from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \
- FLOAT, DATETIME, DATE, BINARY, \
- TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR
+from ...types import (
+ INTEGER,
+ BIGINT,
+ SMALLINT,
+ DECIMAL,
+ NUMERIC,
+ FLOAT,
+ DATETIME,
+ DATE,
+ BINARY,
+ TEXT,
+ VARCHAR,
+ NVARCHAR,
+ CHAR,
+ NCHAR,
+)
from ...util import update_wrapper
@@ -672,48 +685,202 @@ MS_2005_VERSION = (9,)
MS_2000_VERSION = (8,)
RESERVED_WORDS = set(
- ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization',
- 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade',
- 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce',
- 'collate', 'column', 'commit', 'compute', 'constraint', 'contains',
- 'containstable', 'continue', 'convert', 'create', 'cross', 'current',
- 'current_date', 'current_time', 'current_timestamp', 'current_user',
- 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default',
- 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double',
- 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec',
- 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor',
- 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full',
- 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity',
- 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert',
- 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like',
- 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not',
- 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource',
- 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer',
- 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print',
- 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext',
- 'reconfigure', 'references', 'replication', 'restore', 'restrict',
- 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount',
- 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select',
- 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics',
- 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top',
- 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union',
- 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values',
- 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with',
- 'writetext',
- ])
+ [
+ "add",
+ "all",
+ "alter",
+ "and",
+ "any",
+ "as",
+ "asc",
+ "authorization",
+ "backup",
+ "begin",
+ "between",
+ "break",
+ "browse",
+ "bulk",
+ "by",
+ "cascade",
+ "case",
+ "check",
+ "checkpoint",
+ "close",
+ "clustered",
+ "coalesce",
+ "collate",
+ "column",
+ "commit",
+ "compute",
+ "constraint",
+ "contains",
+ "containstable",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "current",
+ "current_date",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "database",
+ "dbcc",
+ "deallocate",
+ "declare",
+ "default",
+ "delete",
+ "deny",
+ "desc",
+ "disk",
+ "distinct",
+ "distributed",
+ "double",
+ "drop",
+ "dump",
+ "else",
+ "end",
+ "errlvl",
+ "escape",
+ "except",
+ "exec",
+ "execute",
+ "exists",
+ "exit",
+ "external",
+ "fetch",
+ "file",
+ "fillfactor",
+ "for",
+ "foreign",
+ "freetext",
+ "freetexttable",
+ "from",
+ "full",
+ "function",
+ "goto",
+ "grant",
+ "group",
+ "having",
+ "holdlock",
+ "identity",
+ "identity_insert",
+ "identitycol",
+ "if",
+ "in",
+ "index",
+ "inner",
+ "insert",
+ "intersect",
+ "into",
+ "is",
+ "join",
+ "key",
+ "kill",
+ "left",
+ "like",
+ "lineno",
+ "load",
+ "merge",
+ "national",
+ "nocheck",
+ "nonclustered",
+ "not",
+ "null",
+ "nullif",
+ "of",
+ "off",
+ "offsets",
+ "on",
+ "open",
+ "opendatasource",
+ "openquery",
+ "openrowset",
+ "openxml",
+ "option",
+ "or",
+ "order",
+ "outer",
+ "over",
+ "percent",
+ "pivot",
+ "plan",
+ "precision",
+ "primary",
+ "print",
+ "proc",
+ "procedure",
+ "public",
+ "raiserror",
+ "read",
+ "readtext",
+ "reconfigure",
+ "references",
+ "replication",
+ "restore",
+ "restrict",
+ "return",
+ "revert",
+ "revoke",
+ "right",
+ "rollback",
+ "rowcount",
+ "rowguidcol",
+ "rule",
+ "save",
+ "schema",
+ "securityaudit",
+ "select",
+ "session_user",
+ "set",
+ "setuser",
+ "shutdown",
+ "some",
+ "statistics",
+ "system_user",
+ "table",
+ "tablesample",
+ "textsize",
+ "then",
+ "to",
+ "top",
+ "tran",
+ "transaction",
+ "trigger",
+ "truncate",
+ "tsequal",
+ "union",
+ "unique",
+ "unpivot",
+ "update",
+ "updatetext",
+ "use",
+ "user",
+ "values",
+ "varying",
+ "view",
+ "waitfor",
+ "when",
+ "where",
+ "while",
+ "with",
+ "writetext",
+ ]
+)
class REAL(sqltypes.REAL):
- __visit_name__ = 'REAL'
+ __visit_name__ = "REAL"
def __init__(self, **kw):
# REAL is a synonym for FLOAT(24) on SQL server
- kw['precision'] = 24
+ kw["precision"] = 24
super(REAL, self).__init__(**kw)
class TINYINT(sqltypes.Integer):
- __visit_name__ = 'TINYINT'
+ __visit_name__ = "TINYINT"
# MSSQL DATE/TIME types have varied behavior, sometimes returning
@@ -721,14 +888,15 @@ class TINYINT(sqltypes.Integer):
# filter bind parameters into datetime objects (required by pyodbc,
# not sure about other dialects).
-class _MSDate(sqltypes.Date):
+class _MSDate(sqltypes.Date):
def bind_processor(self, dialect):
def process(value):
if type(value) == datetime.date:
return datetime.datetime(value.year, value.month, value.day)
else:
return value
+
return process
_reg = re.compile(r"(\d+)-(\d+)-(\d+)")
@@ -741,18 +909,16 @@ class _MSDate(sqltypes.Date):
m = self._reg.match(value)
if not m:
raise ValueError(
- "could not parse %r as a date value" % (value, ))
- return datetime.date(*[
- int(x or 0)
- for x in m.groups()
- ])
+ "could not parse %r as a date value" % (value,)
+ )
+ return datetime.date(*[int(x or 0) for x in m.groups()])
else:
return value
+
return process
class TIME(sqltypes.TIME):
-
def __init__(self, precision=None, **kwargs):
self.precision = precision
super(TIME, self).__init__()
@@ -763,10 +929,12 @@ class TIME(sqltypes.TIME):
def process(value):
if isinstance(value, datetime.datetime):
value = datetime.datetime.combine(
- self.__zero_date, value.time())
+ self.__zero_date, value.time()
+ )
elif isinstance(value, datetime.time):
value = datetime.datetime.combine(self.__zero_date, value)
return value
+
return process
_reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?")
@@ -779,24 +947,26 @@ class TIME(sqltypes.TIME):
m = self._reg.match(value)
if not m:
raise ValueError(
- "could not parse %r as a time value" % (value, ))
- return datetime.time(*[
- int(x or 0)
- for x in m.groups()])
+ "could not parse %r as a time value" % (value,)
+ )
+ return datetime.time(*[int(x or 0) for x in m.groups()])
else:
return value
+
return process
+
+
_MSTime = TIME
class _DateTimeBase(object):
-
def bind_processor(self, dialect):
def process(value):
if type(value) == datetime.date:
return datetime.datetime(value.year, value.month, value.day)
else:
return value
+
return process
@@ -805,11 +975,11 @@ class _MSDateTime(_DateTimeBase, sqltypes.DateTime):
class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime):
- __visit_name__ = 'SMALLDATETIME'
+ __visit_name__ = "SMALLDATETIME"
class DATETIME2(_DateTimeBase, sqltypes.DateTime):
- __visit_name__ = 'DATETIME2'
+ __visit_name__ = "DATETIME2"
def __init__(self, precision=None, **kw):
super(DATETIME2, self).__init__(**kw)
@@ -818,7 +988,7 @@ class DATETIME2(_DateTimeBase, sqltypes.DateTime):
# TODO: is this not an Interval ?
class DATETIMEOFFSET(sqltypes.TypeEngine):
- __visit_name__ = 'DATETIMEOFFSET'
+ __visit_name__ = "DATETIMEOFFSET"
def __init__(self, precision=None, **kwargs):
self.precision = precision
@@ -847,7 +1017,7 @@ class TIMESTAMP(sqltypes._Binary):
"""
- __visit_name__ = 'TIMESTAMP'
+ __visit_name__ = "TIMESTAMP"
# expected by _Binary to be present
length = None
@@ -866,12 +1036,14 @@ class TIMESTAMP(sqltypes._Binary):
def result_processor(self, dialect, coltype):
super_ = super(TIMESTAMP, self).result_processor(dialect, coltype)
if self.convert_int:
+
def process(value):
value = super_(value)
if value is not None:
# https://stackoverflow.com/a/30403242/34549
- value = int(codecs.encode(value, 'hex'), 16)
+ value = int(codecs.encode(value, "hex"), 16)
return value
+
return process
else:
return super_
@@ -898,7 +1070,7 @@ class ROWVERSION(TIMESTAMP):
"""
- __visit_name__ = 'ROWVERSION'
+ __visit_name__ = "ROWVERSION"
class NTEXT(sqltypes.UnicodeText):
@@ -906,7 +1078,7 @@ class NTEXT(sqltypes.UnicodeText):
"""MSSQL NTEXT type, for variable-length unicode text up to 2^30
characters."""
- __visit_name__ = 'NTEXT'
+ __visit_name__ = "NTEXT"
class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary):
@@ -925,11 +1097,12 @@ class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary):
"""
- __visit_name__ = 'VARBINARY'
+
+ __visit_name__ = "VARBINARY"
class IMAGE(sqltypes.LargeBinary):
- __visit_name__ = 'IMAGE'
+ __visit_name__ = "IMAGE"
class XML(sqltypes.Text):
@@ -943,19 +1116,20 @@ class XML(sqltypes.Text):
.. versionadded:: 1.1.11
"""
- __visit_name__ = 'XML'
+
+ __visit_name__ = "XML"
class BIT(sqltypes.TypeEngine):
- __visit_name__ = 'BIT'
+ __visit_name__ = "BIT"
class MONEY(sqltypes.TypeEngine):
- __visit_name__ = 'MONEY'
+ __visit_name__ = "MONEY"
class SMALLMONEY(sqltypes.TypeEngine):
- __visit_name__ = 'SMALLMONEY'
+ __visit_name__ = "SMALLMONEY"
class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
@@ -963,7 +1137,8 @@ class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
class SQL_VARIANT(sqltypes.TypeEngine):
- __visit_name__ = 'SQL_VARIANT'
+ __visit_name__ = "SQL_VARIANT"
+
# old names.
MSDateTime = _MSDateTime
@@ -990,36 +1165,36 @@ MSUniqueIdentifier = UNIQUEIDENTIFIER
MSVariant = SQL_VARIANT
ischema_names = {
- 'int': INTEGER,
- 'bigint': BIGINT,
- 'smallint': SMALLINT,
- 'tinyint': TINYINT,
- 'varchar': VARCHAR,
- 'nvarchar': NVARCHAR,
- 'char': CHAR,
- 'nchar': NCHAR,
- 'text': TEXT,
- 'ntext': NTEXT,
- 'decimal': DECIMAL,
- 'numeric': NUMERIC,
- 'float': FLOAT,
- 'datetime': DATETIME,
- 'datetime2': DATETIME2,
- 'datetimeoffset': DATETIMEOFFSET,
- 'date': DATE,
- 'time': TIME,
- 'smalldatetime': SMALLDATETIME,
- 'binary': BINARY,
- 'varbinary': VARBINARY,
- 'bit': BIT,
- 'real': REAL,
- 'image': IMAGE,
- 'xml': XML,
- 'timestamp': TIMESTAMP,
- 'money': MONEY,
- 'smallmoney': SMALLMONEY,
- 'uniqueidentifier': UNIQUEIDENTIFIER,
- 'sql_variant': SQL_VARIANT,
+ "int": INTEGER,
+ "bigint": BIGINT,
+ "smallint": SMALLINT,
+ "tinyint": TINYINT,
+ "varchar": VARCHAR,
+ "nvarchar": NVARCHAR,
+ "char": CHAR,
+ "nchar": NCHAR,
+ "text": TEXT,
+ "ntext": NTEXT,
+ "decimal": DECIMAL,
+ "numeric": NUMERIC,
+ "float": FLOAT,
+ "datetime": DATETIME,
+ "datetime2": DATETIME2,
+ "datetimeoffset": DATETIMEOFFSET,
+ "date": DATE,
+ "time": TIME,
+ "smalldatetime": SMALLDATETIME,
+ "binary": BINARY,
+ "varbinary": VARBINARY,
+ "bit": BIT,
+ "real": REAL,
+ "image": IMAGE,
+ "xml": XML,
+ "timestamp": TIMESTAMP,
+ "money": MONEY,
+ "smallmoney": SMALLMONEY,
+ "uniqueidentifier": UNIQUEIDENTIFIER,
+ "sql_variant": SQL_VARIANT,
}
@@ -1030,8 +1205,8 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
"""
- if getattr(type_, 'collation', None):
- collation = 'COLLATE %s' % type_.collation
+ if getattr(type_, "collation", None):
+ collation = "COLLATE %s" % type_.collation
else:
collation = None
@@ -1041,15 +1216,14 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
if length:
spec = spec + "(%s)" % length
- return ' '.join([c for c in (spec, collation)
- if c is not None])
+ return " ".join([c for c in (spec, collation) if c is not None])
def visit_FLOAT(self, type_, **kw):
- precision = getattr(type_, 'precision', None)
+ precision = getattr(type_, "precision", None)
if precision is None:
return "FLOAT"
else:
- return "FLOAT(%(precision)s)" % {'precision': precision}
+ return "FLOAT(%(precision)s)" % {"precision": precision}
def visit_TINYINT(self, type_, **kw):
return "TINYINT"
@@ -1061,7 +1235,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return "DATETIMEOFFSET"
def visit_TIME(self, type_, **kw):
- precision = getattr(type_, 'precision', None)
+ precision = getattr(type_, "precision", None)
if precision is not None:
return "TIME(%s)" % precision
else:
@@ -1074,7 +1248,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return "ROWVERSION"
def visit_DATETIME2(self, type_, **kw):
- precision = getattr(type_, 'precision', None)
+ precision = getattr(type_, "precision", None)
if precision is not None:
return "DATETIME2(%s)" % precision
else:
@@ -1105,7 +1279,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return self._extend("TEXT", type_)
def visit_VARCHAR(self, type_, **kw):
- return self._extend("VARCHAR", type_, length=type_.length or 'max')
+ return self._extend("VARCHAR", type_, length=type_.length or "max")
def visit_CHAR(self, type_, **kw):
return self._extend("CHAR", type_)
@@ -1114,7 +1288,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return self._extend("NCHAR", type_)
def visit_NVARCHAR(self, type_, **kw):
- return self._extend("NVARCHAR", type_, length=type_.length or 'max')
+ return self._extend("NVARCHAR", type_, length=type_.length or "max")
def visit_date(self, type_, **kw):
if self.dialect.server_version_info < MS_2008_VERSION:
@@ -1141,10 +1315,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return "XML"
def visit_VARBINARY(self, type_, **kw):
- return self._extend(
- "VARBINARY",
- type_,
- length=type_.length or 'max')
+ return self._extend("VARBINARY", type_, length=type_.length or "max")
def visit_boolean(self, type_, **kw):
return self.visit_BIT(type_)
@@ -1156,13 +1327,13 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
return "MONEY"
def visit_SMALLMONEY(self, type_, **kw):
- return 'SMALLMONEY'
+ return "SMALLMONEY"
def visit_UNIQUEIDENTIFIER(self, type_, **kw):
return "UNIQUEIDENTIFIER"
def visit_SQL_VARIANT(self, type_, **kw):
- return 'SQL_VARIANT'
+ return "SQL_VARIANT"
class MSExecutionContext(default.DefaultExecutionContext):
@@ -1186,41 +1357,44 @@ class MSExecutionContext(default.DefaultExecutionContext):
insert_has_sequence = seq_column is not None
if insert_has_sequence:
- self._enable_identity_insert = \
- seq_column.key in self.compiled_parameters[0] or \
- (
- self.compiled.statement.parameters and (
- (
- self.compiled.statement._has_multi_parameters
- and
- seq_column.key in
- self.compiled.statement.parameters[0]
- ) or (
- not
- self.compiled.statement._has_multi_parameters
- and
- seq_column.key in
- self.compiled.statement.parameters
- )
+ self._enable_identity_insert = seq_column.key in self.compiled_parameters[
+ 0
+ ] or (
+ self.compiled.statement.parameters
+ and (
+ (
+ self.compiled.statement._has_multi_parameters
+ and seq_column.key
+ in self.compiled.statement.parameters[0]
+ )
+ or (
+ not self.compiled.statement._has_multi_parameters
+ and seq_column.key
+ in self.compiled.statement.parameters
)
)
+ )
else:
self._enable_identity_insert = False
- self._select_lastrowid = not self.compiled.inline and \
- insert_has_sequence and \
- not self.compiled.returning and \
- not self._enable_identity_insert and \
- not self.executemany
+ self._select_lastrowid = (
+ not self.compiled.inline
+ and insert_has_sequence
+ and not self.compiled.returning
+ and not self._enable_identity_insert
+ and not self.executemany
+ )
if self._enable_identity_insert:
self.root_connection._cursor_execute(
self.cursor,
self._opt_encode(
- "SET IDENTITY_INSERT %s ON" %
- self.dialect.identifier_preparer.format_table(tbl)),
+ "SET IDENTITY_INSERT %s ON"
+ % self.dialect.identifier_preparer.format_table(tbl)
+ ),
(),
- self)
+ self,
+ )
def post_exec(self):
"""Disable IDENTITY_INSERT if enabled."""
@@ -1230,29 +1404,35 @@ class MSExecutionContext(default.DefaultExecutionContext):
if self.dialect.use_scope_identity:
conn._cursor_execute(
self.cursor,
- "SELECT scope_identity() AS lastrowid", (), self)
+ "SELECT scope_identity() AS lastrowid",
+ (),
+ self,
+ )
else:
- conn._cursor_execute(self.cursor,
- "SELECT @@identity AS lastrowid",
- (),
- self)
+ conn._cursor_execute(
+ self.cursor, "SELECT @@identity AS lastrowid", (), self
+ )
# fetchall() ensures the cursor is consumed without closing it
row = self.cursor.fetchall()[0]
self._lastrowid = int(row[0])
- if (self.isinsert or self.isupdate or self.isdelete) and \
- self.compiled.returning:
+ if (
+ self.isinsert or self.isupdate or self.isdelete
+ ) and self.compiled.returning:
self._result_proxy = engine.FullyBufferedResultProxy(self)
if self._enable_identity_insert:
conn._cursor_execute(
self.cursor,
self._opt_encode(
- "SET IDENTITY_INSERT %s OFF" %
- self.dialect.identifier_preparer. format_table(
- self.compiled.statement.table)),
+ "SET IDENTITY_INSERT %s OFF"
+ % self.dialect.identifier_preparer.format_table(
+ self.compiled.statement.table
+ )
+ ),
(),
- self)
+ self,
+ )
def get_lastrowid(self):
return self._lastrowid
@@ -1262,9 +1442,12 @@ class MSExecutionContext(default.DefaultExecutionContext):
try:
self.cursor.execute(
self._opt_encode(
- "SET IDENTITY_INSERT %s OFF" %
- self.dialect.identifier_preparer. format_table(
- self.compiled.statement.table)))
+ "SET IDENTITY_INSERT %s OFF"
+ % self.dialect.identifier_preparer.format_table(
+ self.compiled.statement.table
+ )
+ )
+ )
except Exception:
pass
@@ -1281,11 +1464,12 @@ class MSSQLCompiler(compiler.SQLCompiler):
extract_map = util.update_copy(
compiler.SQLCompiler.extract_map,
{
- 'doy': 'dayofyear',
- 'dow': 'weekday',
- 'milliseconds': 'millisecond',
- 'microseconds': 'microsecond'
- })
+ "doy": "dayofyear",
+ "dow": "weekday",
+ "milliseconds": "millisecond",
+ "microseconds": "microsecond",
+ },
+ )
def __init__(self, *args, **kwargs):
self.tablealiases = {}
@@ -1298,6 +1482,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
else:
super_ = getattr(super(MSSQLCompiler, self), fn.__name__)
return super_(*arg, **kw)
+
return decorate
def visit_now_func(self, fn, **kw):
@@ -1313,20 +1498,22 @@ class MSSQLCompiler(compiler.SQLCompiler):
return "LEN%s" % self.function_argspec(fn, **kw)
def visit_concat_op_binary(self, binary, operator, **kw):
- return "%s + %s" % \
- (self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ return "%s + %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
def visit_true(self, expr, **kw):
- return '1'
+ return "1"
def visit_false(self, expr, **kw):
- return '0'
+ return "0"
def visit_match_op_binary(self, binary, operator, **kw):
return "CONTAINS (%s, %s)" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ self.process(binary.right, **kw),
+ )
def get_select_precolumns(self, select, **kw):
""" MS-SQL puts TOP, it's version of LIMIT here """
@@ -1345,7 +1532,8 @@ class MSSQLCompiler(compiler.SQLCompiler):
return s
else:
return compiler.SQLCompiler.get_select_precolumns(
- self, select, **kw)
+ self, select, **kw
+ )
def get_from_hint_text(self, table, text):
return text
@@ -1363,20 +1551,21 @@ class MSSQLCompiler(compiler.SQLCompiler):
"""
if (
- (
- not select._simple_int_limit and
- select._limit_clause is not None
- ) or (
- select._offset_clause is not None and
- not select._simple_int_offset or select._offset
+ (not select._simple_int_limit and select._limit_clause is not None)
+ or (
+ select._offset_clause is not None
+ and not select._simple_int_offset
+ or select._offset
)
- ) and not getattr(select, '_mssql_visit', None):
+ ) and not getattr(select, "_mssql_visit", None):
# to use ROW_NUMBER(), an ORDER BY is required.
if not select._order_by_clause.clauses:
- raise exc.CompileError('MSSQL requires an order_by when '
- 'using an OFFSET or a non-simple '
- 'LIMIT clause')
+ raise exc.CompileError(
+ "MSSQL requires an order_by when "
+ "using an OFFSET or a non-simple "
+ "LIMIT clause"
+ )
_order_by_clauses = [
sql_util.unwrap_label_reference(elem)
@@ -1385,24 +1574,31 @@ class MSSQLCompiler(compiler.SQLCompiler):
limit_clause = select._limit_clause
offset_clause = select._offset_clause
- kwargs['select_wraps_for'] = select
+ kwargs["select_wraps_for"] = select
select = select._generate()
select._mssql_visit = True
- select = select.column(
- sql.func.ROW_NUMBER().over(order_by=_order_by_clauses)
- .label("mssql_rn")).order_by(None).alias()
+ select = (
+ select.column(
+ sql.func.ROW_NUMBER()
+ .over(order_by=_order_by_clauses)
+ .label("mssql_rn")
+ )
+ .order_by(None)
+ .alias()
+ )
- mssql_rn = sql.column('mssql_rn')
- limitselect = sql.select([c for c in select.c if
- c.key != 'mssql_rn'])
+ mssql_rn = sql.column("mssql_rn")
+ limitselect = sql.select(
+ [c for c in select.c if c.key != "mssql_rn"]
+ )
if offset_clause is not None:
limitselect.append_whereclause(mssql_rn > offset_clause)
if limit_clause is not None:
limitselect.append_whereclause(
- mssql_rn <= (limit_clause + offset_clause))
+ mssql_rn <= (limit_clause + offset_clause)
+ )
else:
- limitselect.append_whereclause(
- mssql_rn <= (limit_clause))
+ limitselect.append_whereclause(mssql_rn <= (limit_clause))
return self.process(limitselect, **kwargs)
else:
return compiler.SQLCompiler.visit_select(self, select, **kwargs)
@@ -1422,35 +1618,38 @@ class MSSQLCompiler(compiler.SQLCompiler):
@_with_legacy_schema_aliasing
def visit_alias(self, alias, **kw):
# translate for schema-qualified table aliases
- kw['mssql_aliased'] = alias.original
+ kw["mssql_aliased"] = alias.original
return super(MSSQLCompiler, self).visit_alias(alias, **kw)
@_with_legacy_schema_aliasing
def visit_column(self, column, add_to_result_map=None, **kw):
- if column.table is not None and \
- (not self.isupdate and not self.isdelete) or \
- self.is_subquery():
+ if (
+ column.table is not None
+ and (not self.isupdate and not self.isdelete)
+ or self.is_subquery()
+ ):
# translate for schema-qualified table aliases
t = self._schema_aliased_table(column.table)
if t is not None:
converted = expression._corresponding_column_or_error(
- t, column)
+ t, column
+ )
if add_to_result_map is not None:
add_to_result_map(
column.name,
column.name,
(column, column.name, column.key),
- column.type
+ column.type,
)
- return super(MSSQLCompiler, self).\
- visit_column(converted, **kw)
+ return super(MSSQLCompiler, self).visit_column(converted, **kw)
return super(MSSQLCompiler, self).visit_column(
- column, add_to_result_map=add_to_result_map, **kw)
+ column, add_to_result_map=add_to_result_map, **kw
+ )
def _schema_aliased_table(self, table):
- if getattr(table, 'schema', None) is not None:
+ if getattr(table, "schema", None) is not None:
if table not in self.tablealiases:
self.tablealiases[table] = table.alias()
return self.tablealiases[table]
@@ -1459,16 +1658,17 @@ class MSSQLCompiler(compiler.SQLCompiler):
def visit_extract(self, extract, **kw):
field = self.extract_map.get(extract.field, extract.field)
- return 'DATEPART(%s, %s)' % \
- (field, self.process(extract.expr, **kw))
+ return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw))
def visit_savepoint(self, savepoint_stmt):
- return "SAVE TRANSACTION %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
def visit_rollback_to_savepoint(self, savepoint_stmt):
- return ("ROLLBACK TRANSACTION %s"
- % self.preparer.format_savepoint(savepoint_stmt))
+ return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
def visit_binary(self, binary, **kwargs):
"""Move bind parameters to the right-hand side of an operator, where
@@ -1481,10 +1681,11 @@ class MSSQLCompiler(compiler.SQLCompiler):
and not isinstance(binary.right, expression.BindParameter)
):
return self.process(
- expression.BinaryExpression(binary.right,
- binary.left,
- binary.operator),
- **kwargs)
+ expression.BinaryExpression(
+ binary.right, binary.left, binary.operator
+ ),
+ **kwargs
+ )
return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
def returning_clause(self, stmt, returning_cols):
@@ -1497,12 +1698,13 @@ class MSSQLCompiler(compiler.SQLCompiler):
adapter = sql_util.ClauseAdapter(target)
columns = [
- self._label_select_column(None, adapter.traverse(c),
- True, False, {})
+ self._label_select_column(
+ None, adapter.traverse(c), True, False, {}
+ )
for c in expression._select_iterables(returning_cols)
]
- return 'OUTPUT ' + ', '.join(columns)
+ return "OUTPUT " + ", ".join(columns)
def get_cte_preamble(self, recursive):
# SQL Server finds it too inconvenient to accept
@@ -1515,13 +1717,14 @@ class MSSQLCompiler(compiler.SQLCompiler):
if isinstance(column, expression.Function):
return column.label(None)
else:
- return super(MSSQLCompiler, self).\
- label_select_column(select, column, asfrom)
+ return super(MSSQLCompiler, self).label_select_column(
+ select, column, asfrom
+ )
def for_update_clause(self, select):
# "FOR UPDATE" is only allowed on "DECLARE CURSOR" which
# SQLAlchemy doesn't use
- return ''
+ return ""
def order_by_clause(self, select, **kw):
order_by = self.process(select._order_by_clause, **kw)
@@ -1532,10 +1735,9 @@ class MSSQLCompiler(compiler.SQLCompiler):
else:
return ""
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Render the UPDATE..FROM clause specific to MSSQL.
In MSSQL, if the UPDATE statement involves an alias of the table to
@@ -1543,13 +1745,12 @@ class MSSQLCompiler(compiler.SQLCompiler):
well. Otherwise, it is optional. Here, we add it regardless.
"""
- return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in [from_table] + extra_froms)
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
- def delete_table_clause(self, delete_stmt, from_table,
- extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
"""If we have extra froms make sure we render any alias as hint."""
ashint = False
if extra_froms:
@@ -1558,20 +1759,21 @@ class MSSQLCompiler(compiler.SQLCompiler):
self, asfrom=True, iscrud=True, ashint=ashint
)
- def delete_extra_from_clause(self, delete_stmt, from_table,
- extra_froms, from_hints, **kw):
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Render the DELETE .. FROM clause specific to MSSQL.
Yes, it has the FROM keyword twice.
"""
- return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in [from_table] + extra_froms)
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
def visit_empty_set_expr(self, type_):
- return 'SELECT 1 WHERE 1!=1'
+ return "SELECT 1 WHERE 1!=1"
class MSSQLStrictCompiler(MSSQLCompiler):
@@ -1583,20 +1785,21 @@ class MSSQLStrictCompiler(MSSQLCompiler):
binds are used.
"""
+
ansi_bind_rules = True
def visit_in_op_binary(self, binary, operator, **kw):
- kw['literal_binds'] = True
+ kw["literal_binds"] = True
return "%s IN %s" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw)
+ self.process(binary.right, **kw),
)
def visit_notin_op_binary(self, binary, operator, **kw):
- kw['literal_binds'] = True
+ kw["literal_binds"] = True
return "%s NOT IN %s" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw)
+ self.process(binary.right, **kw),
)
def render_literal_value(self, value, type_):
@@ -1615,23 +1818,28 @@ class MSSQLStrictCompiler(MSSQLCompiler):
# SQL Server wants single quotes around the date string.
return "'" + str(value) + "'"
else:
- return super(MSSQLStrictCompiler, self).\
- render_literal_value(value, type_)
+ return super(MSSQLStrictCompiler, self).render_literal_value(
+ value, type_
+ )
class MSDDLCompiler(compiler.DDLCompiler):
-
def get_column_specification(self, column, **kwargs):
colspec = (
- self.preparer.format_column(column) + " "
+ self.preparer.format_column(column)
+ + " "
+ self.dialect.type_compiler.process(
- column.type, type_expression=column)
+ column.type, type_expression=column
+ )
)
if column.nullable is not None:
- if not column.nullable or column.primary_key or \
- isinstance(column.default, sa_schema.Sequence) or \
- column.autoincrement is True:
+ if (
+ not column.nullable
+ or column.primary_key
+ or isinstance(column.default, sa_schema.Sequence)
+ or column.autoincrement is True
+ ):
colspec += " NOT NULL"
else:
colspec += " NULL"
@@ -1639,15 +1847,18 @@ class MSDDLCompiler(compiler.DDLCompiler):
if column.table is None:
raise exc.CompileError(
"mssql requires Table-bound columns "
- "in order to generate DDL")
+ "in order to generate DDL"
+ )
# install an IDENTITY Sequence if we either a sequence or an implicit
# IDENTITY column
if isinstance(column.default, sa_schema.Sequence):
- if (column.default.start is not None or
- column.default.increment is not None or
- column is not column.table._autoincrement_column):
+ if (
+ column.default.start is not None
+ or column.default.increment is not None
+ or column is not column.table._autoincrement_column
+ ):
util.warn_deprecated(
"Use of Sequence with SQL Server in order to affect the "
"parameters of the IDENTITY value is deprecated, as "
@@ -1655,18 +1866,23 @@ class MSDDLCompiler(compiler.DDLCompiler):
"will correspond to an actual SQL Server "
"CREATE SEQUENCE in "
"a future release. Please use the mssql_identity_start "
- "and mssql_identity_increment parameters.")
+ "and mssql_identity_increment parameters."
+ )
if column.default.start == 0:
start = 0
else:
start = column.default.start or 1
- colspec += " IDENTITY(%s,%s)" % (start,
- column.default.increment or 1)
- elif column is column.table._autoincrement_column or \
- column.autoincrement is True:
- start = column.dialect_options['mssql']['identity_start']
- increment = column.dialect_options['mssql']['identity_increment']
+ colspec += " IDENTITY(%s,%s)" % (
+ start,
+ column.default.increment or 1,
+ )
+ elif (
+ column is column.table._autoincrement_column
+ or column.autoincrement is True
+ ):
+ start = column.dialect_options["mssql"]["identity_start"]
+ increment = column.dialect_options["mssql"]["identity_increment"]
colspec += " IDENTITY(%s,%s)" % (start, increment)
else:
default = self.get_column_default_string(column)
@@ -1684,84 +1900,88 @@ class MSDDLCompiler(compiler.DDLCompiler):
text += "UNIQUE "
# handle clustering option
- clustered = index.dialect_options['mssql']['clustered']
+ clustered = index.dialect_options["mssql"]["clustered"]
if clustered is not None:
if clustered:
text += "CLUSTERED "
else:
text += "NONCLUSTERED "
- text += "INDEX %s ON %s (%s)" \
- % (
- self._prepared_index_name(index,
- include_schema=include_schema),
- preparer.format_table(index.table),
- ', '.join(
- self.sql_compiler.process(expr,
- include_table=False,
- literal_binds=True) for
- expr in index.expressions)
- )
+ text += "INDEX %s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=include_schema),
+ preparer.format_table(index.table),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
# handle other included columns
- if index.dialect_options['mssql']['include']:
- inclusions = [index.table.c[col]
- if isinstance(col, util.string_types) else col
- for col in
- index.dialect_options['mssql']['include']
- ]
+ if index.dialect_options["mssql"]["include"]:
+ inclusions = [
+ index.table.c[col]
+ if isinstance(col, util.string_types)
+ else col
+ for col in index.dialect_options["mssql"]["include"]
+ ]
- text += " INCLUDE (%s)" \
- % ', '.join([preparer.quote(c.name)
- for c in inclusions])
+ text += " INCLUDE (%s)" % ", ".join(
+ [preparer.quote(c.name) for c in inclusions]
+ )
return text
def visit_drop_index(self, drop):
return "\nDROP INDEX %s ON %s" % (
self._prepared_index_name(drop.element, include_schema=False),
- self.preparer.format_table(drop.element.table)
+ self.preparer.format_table(drop.element.table),
)
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
- text += "CONSTRAINT %s " % \
- self.preparer.format_constraint(constraint)
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(
+ constraint
+ )
text += "PRIMARY KEY "
- clustered = constraint.dialect_options['mssql']['clustered']
+ clustered = constraint.dialect_options["mssql"]["clustered"]
if clustered is not None:
if clustered:
text += "CLUSTERED "
else:
text += "NONCLUSTERED "
- text += "(%s)" % ', '.join(self.preparer.quote(c.name)
- for c in constraint)
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name) for c in constraint
+ )
text += self.define_constraint_deferrability(constraint)
return text
def visit_unique_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
- text += "CONSTRAINT %s " % \
- self.preparer.format_constraint(constraint)
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(
+ constraint
+ )
text += "UNIQUE "
- clustered = constraint.dialect_options['mssql']['clustered']
+ clustered = constraint.dialect_options["mssql"]["clustered"]
if clustered is not None:
if clustered:
text += "CLUSTERED "
else:
text += "NONCLUSTERED "
- text += "(%s)" % ', '.join(self.preparer.quote(c.name)
- for c in constraint)
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name) for c in constraint
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -1771,8 +1991,11 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
def __init__(self, dialect):
super(MSIdentifierPreparer, self).__init__(
- dialect, initial_quote='[',
- final_quote=']', quote_case_sensitive_collations=False)
+ dialect,
+ initial_quote="[",
+ final_quote="]",
+ quote_case_sensitive_collations=False,
+ )
def _escape_identifier(self, value):
return value
@@ -1783,7 +2006,9 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
dbname, owner = _schema_elements(schema)
if dbname:
result = "%s.%s" % (
- self.quote(dbname, force), self.quote(owner, force))
+ self.quote(dbname, force),
+ self.quote(owner, force),
+ )
elif owner:
result = self.quote(owner, force)
else:
@@ -1794,16 +2019,37 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
def _db_plus_owner_listing(fn):
def wrap(dialect, connection, schema=None, **kw):
dbname, owner = _owner_plus_db(dialect, schema)
- return _switch_db(dbname, connection, fn, dialect, connection,
- dbname, owner, schema, **kw)
+ return _switch_db(
+ dbname,
+ connection,
+ fn,
+ dialect,
+ connection,
+ dbname,
+ owner,
+ schema,
+ **kw
+ )
+
return update_wrapper(wrap, fn)
def _db_plus_owner(fn):
def wrap(dialect, connection, tablename, schema=None, **kw):
dbname, owner = _owner_plus_db(dialect, schema)
- return _switch_db(dbname, connection, fn, dialect, connection,
- tablename, dbname, owner, schema, **kw)
+ return _switch_db(
+ dbname,
+ connection,
+ fn,
+ dialect,
+ connection,
+ tablename,
+ dbname,
+ owner,
+ schema,
+ **kw
+ )
+
return update_wrapper(wrap, fn)
@@ -1837,9 +2083,9 @@ def _schema_elements(schema):
for token in re.split(r"(\[|\]|\.)", schema):
if not token:
continue
- if token == '[':
+ if token == "[":
bracket = True
- elif token == ']':
+ elif token == "]":
bracket = False
elif not bracket and token == ".":
push.append(symbol)
@@ -1857,7 +2103,7 @@ def _schema_elements(schema):
class MSDialect(default.DefaultDialect):
- name = 'mssql'
+ name = "mssql"
supports_default_values = True
supports_empty_insert = False
execution_ctx_cls = MSExecutionContext
@@ -1871,9 +2117,9 @@ class MSDialect(default.DefaultDialect):
sqltypes.Time: TIME,
}
- engine_config_types = default.DefaultDialect.engine_config_types.union([
- ('legacy_schema_aliasing', util.asbool),
- ])
+ engine_config_types = default.DefaultDialect.engine_config_types.union(
+ [("legacy_schema_aliasing", util.asbool)]
+ )
ischema_names = ischema_names
@@ -1890,36 +2136,30 @@ class MSDialect(default.DefaultDialect):
preparer = MSIdentifierPreparer
construct_arguments = [
- (sa_schema.PrimaryKeyConstraint, {
- "clustered": None
- }),
- (sa_schema.UniqueConstraint, {
- "clustered": None
- }),
- (sa_schema.Index, {
- "clustered": None,
- "include": None
- }),
- (sa_schema.Column, {
- "identity_start": 1,
- "identity_increment": 1
- })
+ (sa_schema.PrimaryKeyConstraint, {"clustered": None}),
+ (sa_schema.UniqueConstraint, {"clustered": None}),
+ (sa_schema.Index, {"clustered": None, "include": None}),
+ (sa_schema.Column, {"identity_start": 1, "identity_increment": 1}),
]
- def __init__(self,
- query_timeout=None,
- use_scope_identity=True,
- max_identifier_length=None,
- schema_name="dbo",
- isolation_level=None,
- deprecate_large_types=None,
- legacy_schema_aliasing=False, **opts):
+ def __init__(
+ self,
+ query_timeout=None,
+ use_scope_identity=True,
+ max_identifier_length=None,
+ schema_name="dbo",
+ isolation_level=None,
+ deprecate_large_types=None,
+ legacy_schema_aliasing=False,
+ **opts
+ ):
self.query_timeout = int(query_timeout or 0)
self.schema_name = schema_name
self.use_scope_identity = use_scope_identity
- self.max_identifier_length = int(max_identifier_length or 0) or \
- self.max_identifier_length
+ self.max_identifier_length = (
+ int(max_identifier_length or 0) or self.max_identifier_length
+ )
self.deprecate_large_types = deprecate_large_types
self.legacy_schema_aliasing = legacy_schema_aliasing
@@ -1936,27 +2176,33 @@ class MSDialect(default.DefaultDialect):
# SQL Server does not support RELEASE SAVEPOINT
pass
- _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED',
- 'READ COMMITTED', 'REPEATABLE READ',
- 'SNAPSHOT'])
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "SNAPSHOT",
+ ]
+ )
def set_isolation_level(self, connection, level):
- level = level.replace('_', ' ')
+ level = level.replace("_", " ")
if level not in self._isolation_lookup:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
- "Valid isolation levels for %s are %s" %
- (level, self.name, ", ".join(self._isolation_lookup))
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
)
cursor = connection.cursor()
- cursor.execute(
- "SET TRANSACTION ISOLATION LEVEL %s" % level)
+ cursor.execute("SET TRANSACTION ISOLATION LEVEL %s" % level)
cursor.close()
def get_isolation_level(self, connection):
if self.server_version_info < MS_2005_VERSION:
raise NotImplementedError(
- "Can't fetch isolation level prior to SQL Server 2005")
+ "Can't fetch isolation level prior to SQL Server 2005"
+ )
last_error = None
@@ -1964,7 +2210,8 @@ class MSDialect(default.DefaultDialect):
for view in views:
cursor = connection.cursor()
try:
- cursor.execute("""
+ cursor.execute(
+ """
SELECT CASE transaction_isolation_level
WHEN 0 THEN NULL
WHEN 1 THEN 'READ UNCOMMITTED'
@@ -1974,7 +2221,9 @@ class MSDialect(default.DefaultDialect):
WHEN 5 THEN 'SNAPSHOT' END AS TRANSACTION_ISOLATION_LEVEL
FROM %s
where session_id = @@SPID
- """ % view)
+ """
+ % view
+ )
val = cursor.fetchone()[0]
except self.dbapi.Error as err:
# Python3 scoping rules
@@ -1987,7 +2236,8 @@ class MSDialect(default.DefaultDialect):
else:
util.warn(
"Could not fetch transaction isolation level, "
- "tried views: %s; final error was: %s" % (views, last_error))
+ "tried views: %s; final error was: %s" % (views, last_error)
+ )
raise NotImplementedError(
"Can't fetch isolation level on this particular "
@@ -2000,8 +2250,10 @@ class MSDialect(default.DefaultDialect):
def on_connect(self):
if self.isolation_level is not None:
+
def connect(conn):
self.set_isolation_level(conn, self.isolation_level)
+
return connect
else:
return None
@@ -2010,16 +2262,20 @@ class MSDialect(default.DefaultDialect):
if self.server_version_info[0] not in list(range(8, 17)):
util.warn(
"Unrecognized server version info '%s'. Some SQL Server "
- "features may not function properly." %
- ".".join(str(x) for x in self.server_version_info))
- if self.server_version_info >= MS_2005_VERSION and \
- 'implicit_returning' not in self.__dict__:
+ "features may not function properly."
+ % ".".join(str(x) for x in self.server_version_info)
+ )
+ if (
+ self.server_version_info >= MS_2005_VERSION
+ and "implicit_returning" not in self.__dict__
+ ):
self.implicit_returning = True
if self.server_version_info >= MS_2008_VERSION:
self.supports_multivalues_insert = True
if self.deprecate_large_types is None:
- self.deprecate_large_types = \
+ self.deprecate_large_types = (
self.server_version_info >= MS_2012_VERSION
+ )
def _get_default_schema_name(self, connection):
if self.server_version_info < MS_2005_VERSION:
@@ -2039,17 +2295,19 @@ class MSDialect(default.DefaultDialect):
whereclause = columns.c.table_name == tablename
if owner:
- whereclause = sql.and_(whereclause,
- columns.c.table_schema == owner)
+ whereclause = sql.and_(
+ whereclause, columns.c.table_schema == owner
+ )
s = sql.select([columns], whereclause)
c = connection.execute(s)
return c.first() is not None
@reflection.cache
def get_schema_names(self, connection, **kw):
- s = sql.select([ischema.schemata.c.schema_name],
- order_by=[ischema.schemata.c.schema_name]
- )
+ s = sql.select(
+ [ischema.schemata.c.schema_name],
+ order_by=[ischema.schemata.c.schema_name],
+ )
schema_names = [r[0] for r in connection.execute(s)]
return schema_names
@@ -2057,12 +2315,13 @@ class MSDialect(default.DefaultDialect):
@_db_plus_owner_listing
def get_table_names(self, connection, dbname, owner, schema, **kw):
tables = ischema.tables
- s = sql.select([tables.c.table_name],
- sql.and_(
- tables.c.table_schema == owner,
- tables.c.table_type == 'BASE TABLE'
- ),
- order_by=[tables.c.table_name]
+ s = sql.select(
+ [tables.c.table_name],
+ sql.and_(
+ tables.c.table_schema == owner,
+ tables.c.table_type == "BASE TABLE",
+ ),
+ order_by=[tables.c.table_name],
)
table_names = [r[0] for r in connection.execute(s)]
return table_names
@@ -2071,12 +2330,12 @@ class MSDialect(default.DefaultDialect):
@_db_plus_owner_listing
def get_view_names(self, connection, dbname, owner, schema, **kw):
tables = ischema.tables
- s = sql.select([tables.c.table_name],
- sql.and_(
- tables.c.table_schema == owner,
- tables.c.table_type == 'VIEW'
- ),
- order_by=[tables.c.table_name]
+ s = sql.select(
+ [tables.c.table_name],
+ sql.and_(
+ tables.c.table_schema == owner, tables.c.table_type == "VIEW"
+ ),
+ order_by=[tables.c.table_name],
)
view_names = [r[0] for r in connection.execute(s)]
return view_names
@@ -2090,30 +2349,33 @@ class MSDialect(default.DefaultDialect):
return []
rp = connection.execute(
- sql.text("select ind.index_id, ind.is_unique, ind.name "
- "from sys.indexes as ind join sys.tables as tab on "
- "ind.object_id=tab.object_id "
- "join sys.schemas as sch on sch.schema_id=tab.schema_id "
- "where tab.name = :tabname "
- "and sch.name=:schname "
- "and ind.is_primary_key=0 and ind.type != 0",
- bindparams=[
- sql.bindparam('tabname', tablename,
- sqltypes.String(convert_unicode=True)),
- sql.bindparam('schname', owner,
- sqltypes.String(convert_unicode=True))
- ],
- typemap={
- 'name': sqltypes.Unicode()
- }
- )
+ sql.text(
+ "select ind.index_id, ind.is_unique, ind.name "
+ "from sys.indexes as ind join sys.tables as tab on "
+ "ind.object_id=tab.object_id "
+ "join sys.schemas as sch on sch.schema_id=tab.schema_id "
+ "where tab.name = :tabname "
+ "and sch.name=:schname "
+ "and ind.is_primary_key=0 and ind.type != 0",
+ bindparams=[
+ sql.bindparam(
+ "tabname",
+ tablename,
+ sqltypes.String(convert_unicode=True),
+ ),
+ sql.bindparam(
+ "schname", owner, sqltypes.String(convert_unicode=True)
+ ),
+ ],
+ typemap={"name": sqltypes.Unicode()},
+ )
)
indexes = {}
for row in rp:
- indexes[row['index_id']] = {
- 'name': row['name'],
- 'unique': row['is_unique'] == 1,
- 'column_names': []
+ indexes[row["index_id"]] = {
+ "name": row["name"],
+ "unique": row["is_unique"] == 1,
+ "column_names": [],
}
rp = connection.execute(
sql.text(
@@ -2127,24 +2389,29 @@ class MSDialect(default.DefaultDialect):
"where tab.name=:tabname "
"and sch.name=:schname",
bindparams=[
- sql.bindparam('tabname', tablename,
- sqltypes.String(convert_unicode=True)),
- sql.bindparam('schname', owner,
- sqltypes.String(convert_unicode=True))
+ sql.bindparam(
+ "tabname",
+ tablename,
+ sqltypes.String(convert_unicode=True),
+ ),
+ sql.bindparam(
+ "schname", owner, sqltypes.String(convert_unicode=True)
+ ),
],
- typemap={'name': sqltypes.Unicode()}
- ),
+ typemap={"name": sqltypes.Unicode()},
+ )
)
for row in rp:
- if row['index_id'] in indexes:
- indexes[row['index_id']]['column_names'].append(row['name'])
+ if row["index_id"] in indexes:
+ indexes[row["index_id"]]["column_names"].append(row["name"])
return list(indexes.values())
@reflection.cache
@_db_plus_owner
- def get_view_definition(self, connection, viewname,
- dbname, owner, schema, **kw):
+ def get_view_definition(
+ self, connection, viewname, dbname, owner, schema, **kw
+ ):
rp = connection.execute(
sql.text(
"select definition from sys.sql_modules as mod, "
@@ -2155,11 +2422,15 @@ class MSDialect(default.DefaultDialect):
"views.schema_id=sch.schema_id and "
"views.name=:viewname and sch.name=:schname",
bindparams=[
- sql.bindparam('viewname', viewname,
- sqltypes.String(convert_unicode=True)),
- sql.bindparam('schname', owner,
- sqltypes.String(convert_unicode=True))
- ]
+ sql.bindparam(
+ "viewname",
+ viewname,
+ sqltypes.String(convert_unicode=True),
+ ),
+ sql.bindparam(
+ "schname", owner, sqltypes.String(convert_unicode=True)
+ ),
+ ],
)
)
@@ -2173,12 +2444,15 @@ class MSDialect(default.DefaultDialect):
# Get base columns
columns = ischema.columns
if owner:
- whereclause = sql.and_(columns.c.table_name == tablename,
- columns.c.table_schema == owner)
+ whereclause = sql.and_(
+ columns.c.table_name == tablename,
+ columns.c.table_schema == owner,
+ )
else:
whereclause = columns.c.table_name == tablename
- s = sql.select([columns], whereclause,
- order_by=[columns.c.ordinal_position])
+ s = sql.select(
+ [columns], whereclause, order_by=[columns.c.ordinal_position]
+ )
c = connection.execute(s)
cols = []
@@ -2186,57 +2460,76 @@ class MSDialect(default.DefaultDialect):
row = c.fetchone()
if row is None:
break
- (name, type, nullable, charlen,
- numericprec, numericscale, default, collation) = (
+ (
+ name,
+ type,
+ nullable,
+ charlen,
+ numericprec,
+ numericscale,
+ default,
+ collation,
+ ) = (
row[columns.c.column_name],
row[columns.c.data_type],
- row[columns.c.is_nullable] == 'YES',
+ row[columns.c.is_nullable] == "YES",
row[columns.c.character_maximum_length],
row[columns.c.numeric_precision],
row[columns.c.numeric_scale],
row[columns.c.column_default],
- row[columns.c.collation_name]
+ row[columns.c.collation_name],
)
coltype = self.ischema_names.get(type, None)
kwargs = {}
- if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText,
- MSNText, MSBinary, MSVarBinary,
- sqltypes.LargeBinary):
+ if coltype in (
+ MSString,
+ MSChar,
+ MSNVarchar,
+ MSNChar,
+ MSText,
+ MSNText,
+ MSBinary,
+ MSVarBinary,
+ sqltypes.LargeBinary,
+ ):
if charlen == -1:
charlen = None
- kwargs['length'] = charlen
+ kwargs["length"] = charlen
if collation:
- kwargs['collation'] = collation
+ kwargs["collation"] = collation
if coltype is None:
util.warn(
- "Did not recognize type '%s' of column '%s'" %
- (type, name))
+ "Did not recognize type '%s' of column '%s'" % (type, name)
+ )
coltype = sqltypes.NULLTYPE
else:
- if issubclass(coltype, sqltypes.Numeric) and \
- coltype is not MSReal:
- kwargs['scale'] = numericscale
- kwargs['precision'] = numericprec
+ if (
+ issubclass(coltype, sqltypes.Numeric)
+ and coltype is not MSReal
+ ):
+ kwargs["scale"] = numericscale
+ kwargs["precision"] = numericprec
coltype = coltype(**kwargs)
cdict = {
- 'name': name,
- 'type': coltype,
- 'nullable': nullable,
- 'default': default,
- 'autoincrement': False,
+ "name": name,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": False,
}
cols.append(cdict)
# autoincrement and identity
colmap = {}
for col in cols:
- colmap[col['name']] = col
+ colmap[col["name"]] = col
# We also run an sp_columns to check for identity columns:
- cursor = connection.execute("sp_columns @table_name = '%s', "
- "@table_owner = '%s'"
- % (tablename, owner))
+ cursor = connection.execute(
+ "sp_columns @table_name = '%s', "
+ "@table_owner = '%s'" % (tablename, owner)
+ )
ic = None
while True:
row = cursor.fetchone()
@@ -2245,10 +2538,10 @@ class MSDialect(default.DefaultDialect):
(col_name, type_name) = row[3], row[5]
if type_name.endswith("identity") and col_name in colmap:
ic = col_name
- colmap[col_name]['autoincrement'] = True
- colmap[col_name]['dialect_options'] = {
- 'mssql_identity_start': 1,
- 'mssql_identity_increment': 1
+ colmap[col_name]["autoincrement"] = True
+ colmap[col_name]["dialect_options"] = {
+ "mssql_identity_start": 1,
+ "mssql_identity_increment": 1,
}
break
cursor.close()
@@ -2262,64 +2555,74 @@ class MSDialect(default.DefaultDialect):
row = cursor.first()
if row is not None and row[0] is not None:
- colmap[ic]['dialect_options'].update({
- 'mssql_identity_start': int(row[0]),
- 'mssql_identity_increment': int(row[1])
- })
+ colmap[ic]["dialect_options"].update(
+ {
+ "mssql_identity_start": int(row[0]),
+ "mssql_identity_increment": int(row[1]),
+ }
+ )
return cols
@reflection.cache
@_db_plus_owner
- def get_pk_constraint(self, connection, tablename,
- dbname, owner, schema, **kw):
+ def get_pk_constraint(
+ self, connection, tablename, dbname, owner, schema, **kw
+ ):
pkeys = []
TC = ischema.constraints
- C = ischema.key_constraints.alias('C')
+ C = ischema.key_constraints.alias("C")
# Primary key constraints
- s = sql.select([C.c.column_name,
- TC.c.constraint_type,
- C.c.constraint_name],
- sql.and_(TC.c.constraint_name == C.c.constraint_name,
- TC.c.table_schema == C.c.table_schema,
- C.c.table_name == tablename,
- C.c.table_schema == owner)
- )
+ s = sql.select(
+ [C.c.column_name, TC.c.constraint_type, C.c.constraint_name],
+ sql.and_(
+ TC.c.constraint_name == C.c.constraint_name,
+ TC.c.table_schema == C.c.table_schema,
+ C.c.table_name == tablename,
+ C.c.table_schema == owner,
+ ),
+ )
c = connection.execute(s)
constraint_name = None
for row in c:
- if 'PRIMARY' in row[TC.c.constraint_type.name]:
+ if "PRIMARY" in row[TC.c.constraint_type.name]:
pkeys.append(row[0])
if constraint_name is None:
constraint_name = row[C.c.constraint_name.name]
- return {'constrained_columns': pkeys, 'name': constraint_name}
+ return {"constrained_columns": pkeys, "name": constraint_name}
@reflection.cache
@_db_plus_owner
- def get_foreign_keys(self, connection, tablename,
- dbname, owner, schema, **kw):
+ def get_foreign_keys(
+ self, connection, tablename, dbname, owner, schema, **kw
+ ):
RR = ischema.ref_constraints
- C = ischema.key_constraints.alias('C')
- R = ischema.key_constraints.alias('R')
+ C = ischema.key_constraints.alias("C")
+ R = ischema.key_constraints.alias("R")
# Foreign key constraints
- s = sql.select([C.c.column_name,
- R.c.table_schema, R.c.table_name, R.c.column_name,
- RR.c.constraint_name, RR.c.match_option,
- RR.c.update_rule,
- RR.c.delete_rule],
- sql.and_(C.c.table_name == tablename,
- C.c.table_schema == owner,
- RR.c.constraint_schema == C.c.table_schema,
- C.c.constraint_name == RR.c.constraint_name,
- R.c.constraint_name ==
- RR.c.unique_constraint_name,
- R.c.constraint_schema ==
- RR.c.unique_constraint_schema,
- C.c.ordinal_position == R.c.ordinal_position
- ),
- order_by=[RR.c.constraint_name, R.c.ordinal_position]
- )
+ s = sql.select(
+ [
+ C.c.column_name,
+ R.c.table_schema,
+ R.c.table_name,
+ R.c.column_name,
+ RR.c.constraint_name,
+ RR.c.match_option,
+ RR.c.update_rule,
+ RR.c.delete_rule,
+ ],
+ sql.and_(
+ C.c.table_name == tablename,
+ C.c.table_schema == owner,
+ RR.c.constraint_schema == C.c.table_schema,
+ C.c.constraint_name == RR.c.constraint_name,
+ R.c.constraint_name == RR.c.unique_constraint_name,
+ R.c.constraint_schema == RR.c.unique_constraint_schema,
+ C.c.ordinal_position == R.c.ordinal_position,
+ ),
+ order_by=[RR.c.constraint_name, R.c.ordinal_position],
+ )
# group rows by constraint ID, to handle multi-column FKs
fkeys = []
@@ -2327,11 +2630,11 @@ class MSDialect(default.DefaultDialect):
def fkey_rec():
return {
- 'name': None,
- 'constrained_columns': [],
- 'referred_schema': None,
- 'referred_table': None,
- 'referred_columns': []
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": None,
+ "referred_columns": [],
}
fkeys = util.defaultdict(fkey_rec)
@@ -2340,17 +2643,18 @@ class MSDialect(default.DefaultDialect):
scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
rec = fkeys[rfknm]
- rec['name'] = rfknm
- if not rec['referred_table']:
- rec['referred_table'] = rtbl
+ rec["name"] = rfknm
+ if not rec["referred_table"]:
+ rec["referred_table"] = rtbl
if schema is not None or owner != rschema:
if dbname:
rschema = dbname + "." + rschema
- rec['referred_schema'] = rschema
+ rec["referred_schema"] = rschema
- local_cols, remote_cols = \
- rec['constrained_columns'],\
- rec['referred_columns']
+ local_cols, remote_cols = (
+ rec["constrained_columns"],
+ rec["referred_columns"],
+ )
local_cols.append(scol)
remote_cols.append(rcol)