diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/mssql/base.py')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 1238 |
1 files changed, 771 insertions, 467 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 9269225d3..161297015 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -655,9 +655,22 @@ from ...sql import compiler, expression, util as sql_util, quoted_name from ... import engine from ...engine import reflection, default from ... import types as sqltypes -from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ - FLOAT, DATETIME, DATE, BINARY, \ - TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR +from ...types import ( + INTEGER, + BIGINT, + SMALLINT, + DECIMAL, + NUMERIC, + FLOAT, + DATETIME, + DATE, + BINARY, + TEXT, + VARCHAR, + NVARCHAR, + CHAR, + NCHAR, +) from ...util import update_wrapper @@ -672,48 +685,202 @@ MS_2005_VERSION = (9,) MS_2000_VERSION = (8,) RESERVED_WORDS = set( - ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', - 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', - 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', - 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', - 'containstable', 'continue', 'convert', 'create', 'cross', 'current', - 'current_date', 'current_time', 'current_timestamp', 'current_user', - 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', - 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', - 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', - 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', - 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', - 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', - 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', - 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', - 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', - 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', - 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', - 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', - 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', - 'reconfigure', 'references', 'replication', 'restore', 'restrict', - 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', - 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', - 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', - 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', - 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', - 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', - 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', - 'writetext', - ]) + [ + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "authorization", + "backup", + "begin", + "between", + "break", + "browse", + "bulk", + "by", + "cascade", + "case", + "check", + "checkpoint", + "close", + "clustered", + "coalesce", + "collate", + "column", + "commit", + "compute", + "constraint", + "contains", + "containstable", + "continue", + "convert", + "create", + "cross", + "current", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "dbcc", + "deallocate", + "declare", + "default", + "delete", + "deny", + "desc", + "disk", + "distinct", + "distributed", + "double", + "drop", + "dump", + "else", + "end", + "errlvl", + "escape", + "except", + "exec", + "execute", + "exists", + "exit", + "external", + "fetch", + "file", + "fillfactor", + "for", + "foreign", + "freetext", + "freetexttable", + "from", + "full", + "function", + "goto", + "grant", + "group", + "having", + "holdlock", + "identity", + "identity_insert", + "identitycol", + "if", + "in", + "index", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "key", + "kill", + "left", + "like", + "lineno", + "load", + "merge", + "national", + "nocheck", + "nonclustered", + "not", + "null", + "nullif", + "of", + "off", + "offsets", + "on", + "open", + "opendatasource", + "openquery", + "openrowset", + "openxml", + "option", + "or", + "order", + "outer", + "over", + "percent", + "pivot", + "plan", + "precision", + "primary", + "print", + "proc", + "procedure", + "public", + "raiserror", + "read", + "readtext", + "reconfigure", + "references", + "replication", + "restore", + "restrict", + "return", + "revert", + "revoke", + "right", + "rollback", + "rowcount", + "rowguidcol", + "rule", + "save", + "schema", + "securityaudit", + "select", + "session_user", + "set", + "setuser", + "shutdown", + "some", + "statistics", + "system_user", + "table", + "tablesample", + "textsize", + "then", + "to", + "top", + "tran", + "transaction", + "trigger", + "truncate", + "tsequal", + "union", + "unique", + "unpivot", + "update", + "updatetext", + "use", + "user", + "values", + "varying", + "view", + "waitfor", + "when", + "where", + "while", + "with", + "writetext", + ] +) class REAL(sqltypes.REAL): - __visit_name__ = 'REAL' + __visit_name__ = "REAL" def __init__(self, **kw): # REAL is a synonym for FLOAT(24) on SQL server - kw['precision'] = 24 + kw["precision"] = 24 super(REAL, self).__init__(**kw) class TINYINT(sqltypes.Integer): - __visit_name__ = 'TINYINT' + __visit_name__ = "TINYINT" # MSSQL DATE/TIME types have varied behavior, sometimes returning @@ -721,14 +888,15 @@ class TINYINT(sqltypes.Integer): # filter bind parameters into datetime objects (required by pyodbc, # not sure about other dialects). -class _MSDate(sqltypes.Date): +class _MSDate(sqltypes.Date): def bind_processor(self, dialect): def process(value): if type(value) == datetime.date: return datetime.datetime(value.year, value.month, value.day) else: return value + return process _reg = re.compile(r"(\d+)-(\d+)-(\d+)") @@ -741,18 +909,16 @@ class _MSDate(sqltypes.Date): m = self._reg.match(value) if not m: raise ValueError( - "could not parse %r as a date value" % (value, )) - return datetime.date(*[ - int(x or 0) - for x in m.groups() - ]) + "could not parse %r as a date value" % (value,) + ) + return datetime.date(*[int(x or 0) for x in m.groups()]) else: return value + return process class TIME(sqltypes.TIME): - def __init__(self, precision=None, **kwargs): self.precision = precision super(TIME, self).__init__() @@ -763,10 +929,12 @@ class TIME(sqltypes.TIME): def process(value): if isinstance(value, datetime.datetime): value = datetime.datetime.combine( - self.__zero_date, value.time()) + self.__zero_date, value.time() + ) elif isinstance(value, datetime.time): value = datetime.datetime.combine(self.__zero_date, value) return value + return process _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?") @@ -779,24 +947,26 @@ class TIME(sqltypes.TIME): m = self._reg.match(value) if not m: raise ValueError( - "could not parse %r as a time value" % (value, )) - return datetime.time(*[ - int(x or 0) - for x in m.groups()]) + "could not parse %r as a time value" % (value,) + ) + return datetime.time(*[int(x or 0) for x in m.groups()]) else: return value + return process + + _MSTime = TIME class _DateTimeBase(object): - def bind_processor(self, dialect): def process(value): if type(value) == datetime.date: return datetime.datetime(value.year, value.month, value.day) else: return value + return process @@ -805,11 +975,11 @@ class _MSDateTime(_DateTimeBase, sqltypes.DateTime): class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): - __visit_name__ = 'SMALLDATETIME' + __visit_name__ = "SMALLDATETIME" class DATETIME2(_DateTimeBase, sqltypes.DateTime): - __visit_name__ = 'DATETIME2' + __visit_name__ = "DATETIME2" def __init__(self, precision=None, **kw): super(DATETIME2, self).__init__(**kw) @@ -818,7 +988,7 @@ class DATETIME2(_DateTimeBase, sqltypes.DateTime): # TODO: is this not an Interval ? class DATETIMEOFFSET(sqltypes.TypeEngine): - __visit_name__ = 'DATETIMEOFFSET' + __visit_name__ = "DATETIMEOFFSET" def __init__(self, precision=None, **kwargs): self.precision = precision @@ -847,7 +1017,7 @@ class TIMESTAMP(sqltypes._Binary): """ - __visit_name__ = 'TIMESTAMP' + __visit_name__ = "TIMESTAMP" # expected by _Binary to be present length = None @@ -866,12 +1036,14 @@ class TIMESTAMP(sqltypes._Binary): def result_processor(self, dialect, coltype): super_ = super(TIMESTAMP, self).result_processor(dialect, coltype) if self.convert_int: + def process(value): value = super_(value) if value is not None: # https://stackoverflow.com/a/30403242/34549 - value = int(codecs.encode(value, 'hex'), 16) + value = int(codecs.encode(value, "hex"), 16) return value + return process else: return super_ @@ -898,7 +1070,7 @@ class ROWVERSION(TIMESTAMP): """ - __visit_name__ = 'ROWVERSION' + __visit_name__ = "ROWVERSION" class NTEXT(sqltypes.UnicodeText): @@ -906,7 +1078,7 @@ class NTEXT(sqltypes.UnicodeText): """MSSQL NTEXT type, for variable-length unicode text up to 2^30 characters.""" - __visit_name__ = 'NTEXT' + __visit_name__ = "NTEXT" class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): @@ -925,11 +1097,12 @@ class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): """ - __visit_name__ = 'VARBINARY' + + __visit_name__ = "VARBINARY" class IMAGE(sqltypes.LargeBinary): - __visit_name__ = 'IMAGE' + __visit_name__ = "IMAGE" class XML(sqltypes.Text): @@ -943,19 +1116,20 @@ class XML(sqltypes.Text): .. versionadded:: 1.1.11 """ - __visit_name__ = 'XML' + + __visit_name__ = "XML" class BIT(sqltypes.TypeEngine): - __visit_name__ = 'BIT' + __visit_name__ = "BIT" class MONEY(sqltypes.TypeEngine): - __visit_name__ = 'MONEY' + __visit_name__ = "MONEY" class SMALLMONEY(sqltypes.TypeEngine): - __visit_name__ = 'SMALLMONEY' + __visit_name__ = "SMALLMONEY" class UNIQUEIDENTIFIER(sqltypes.TypeEngine): @@ -963,7 +1137,8 @@ class UNIQUEIDENTIFIER(sqltypes.TypeEngine): class SQL_VARIANT(sqltypes.TypeEngine): - __visit_name__ = 'SQL_VARIANT' + __visit_name__ = "SQL_VARIANT" + # old names. MSDateTime = _MSDateTime @@ -990,36 +1165,36 @@ MSUniqueIdentifier = UNIQUEIDENTIFIER MSVariant = SQL_VARIANT ischema_names = { - 'int': INTEGER, - 'bigint': BIGINT, - 'smallint': SMALLINT, - 'tinyint': TINYINT, - 'varchar': VARCHAR, - 'nvarchar': NVARCHAR, - 'char': CHAR, - 'nchar': NCHAR, - 'text': TEXT, - 'ntext': NTEXT, - 'decimal': DECIMAL, - 'numeric': NUMERIC, - 'float': FLOAT, - 'datetime': DATETIME, - 'datetime2': DATETIME2, - 'datetimeoffset': DATETIMEOFFSET, - 'date': DATE, - 'time': TIME, - 'smalldatetime': SMALLDATETIME, - 'binary': BINARY, - 'varbinary': VARBINARY, - 'bit': BIT, - 'real': REAL, - 'image': IMAGE, - 'xml': XML, - 'timestamp': TIMESTAMP, - 'money': MONEY, - 'smallmoney': SMALLMONEY, - 'uniqueidentifier': UNIQUEIDENTIFIER, - 'sql_variant': SQL_VARIANT, + "int": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "tinyint": TINYINT, + "varchar": VARCHAR, + "nvarchar": NVARCHAR, + "char": CHAR, + "nchar": NCHAR, + "text": TEXT, + "ntext": NTEXT, + "decimal": DECIMAL, + "numeric": NUMERIC, + "float": FLOAT, + "datetime": DATETIME, + "datetime2": DATETIME2, + "datetimeoffset": DATETIMEOFFSET, + "date": DATE, + "time": TIME, + "smalldatetime": SMALLDATETIME, + "binary": BINARY, + "varbinary": VARBINARY, + "bit": BIT, + "real": REAL, + "image": IMAGE, + "xml": XML, + "timestamp": TIMESTAMP, + "money": MONEY, + "smallmoney": SMALLMONEY, + "uniqueidentifier": UNIQUEIDENTIFIER, + "sql_variant": SQL_VARIANT, } @@ -1030,8 +1205,8 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): """ - if getattr(type_, 'collation', None): - collation = 'COLLATE %s' % type_.collation + if getattr(type_, "collation", None): + collation = "COLLATE %s" % type_.collation else: collation = None @@ -1041,15 +1216,14 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): if length: spec = spec + "(%s)" % length - return ' '.join([c for c in (spec, collation) - if c is not None]) + return " ".join([c for c in (spec, collation) if c is not None]) def visit_FLOAT(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is None: return "FLOAT" else: - return "FLOAT(%(precision)s)" % {'precision': precision} + return "FLOAT(%(precision)s)" % {"precision": precision} def visit_TINYINT(self, type_, **kw): return "TINYINT" @@ -1061,7 +1235,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "DATETIMEOFFSET" def visit_TIME(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is not None: return "TIME(%s)" % precision else: @@ -1074,7 +1248,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "ROWVERSION" def visit_DATETIME2(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is not None: return "DATETIME2(%s)" % precision else: @@ -1105,7 +1279,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return self._extend("TEXT", type_) def visit_VARCHAR(self, type_, **kw): - return self._extend("VARCHAR", type_, length=type_.length or 'max') + return self._extend("VARCHAR", type_, length=type_.length or "max") def visit_CHAR(self, type_, **kw): return self._extend("CHAR", type_) @@ -1114,7 +1288,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return self._extend("NCHAR", type_) def visit_NVARCHAR(self, type_, **kw): - return self._extend("NVARCHAR", type_, length=type_.length or 'max') + return self._extend("NVARCHAR", type_, length=type_.length or "max") def visit_date(self, type_, **kw): if self.dialect.server_version_info < MS_2008_VERSION: @@ -1141,10 +1315,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "XML" def visit_VARBINARY(self, type_, **kw): - return self._extend( - "VARBINARY", - type_, - length=type_.length or 'max') + return self._extend("VARBINARY", type_, length=type_.length or "max") def visit_boolean(self, type_, **kw): return self.visit_BIT(type_) @@ -1156,13 +1327,13 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "MONEY" def visit_SMALLMONEY(self, type_, **kw): - return 'SMALLMONEY' + return "SMALLMONEY" def visit_UNIQUEIDENTIFIER(self, type_, **kw): return "UNIQUEIDENTIFIER" def visit_SQL_VARIANT(self, type_, **kw): - return 'SQL_VARIANT' + return "SQL_VARIANT" class MSExecutionContext(default.DefaultExecutionContext): @@ -1186,41 +1357,44 @@ class MSExecutionContext(default.DefaultExecutionContext): insert_has_sequence = seq_column is not None if insert_has_sequence: - self._enable_identity_insert = \ - seq_column.key in self.compiled_parameters[0] or \ - ( - self.compiled.statement.parameters and ( - ( - self.compiled.statement._has_multi_parameters - and - seq_column.key in - self.compiled.statement.parameters[0] - ) or ( - not - self.compiled.statement._has_multi_parameters - and - seq_column.key in - self.compiled.statement.parameters - ) + self._enable_identity_insert = seq_column.key in self.compiled_parameters[ + 0 + ] or ( + self.compiled.statement.parameters + and ( + ( + self.compiled.statement._has_multi_parameters + and seq_column.key + in self.compiled.statement.parameters[0] + ) + or ( + not self.compiled.statement._has_multi_parameters + and seq_column.key + in self.compiled.statement.parameters ) ) + ) else: self._enable_identity_insert = False - self._select_lastrowid = not self.compiled.inline and \ - insert_has_sequence and \ - not self.compiled.returning and \ - not self._enable_identity_insert and \ - not self.executemany + self._select_lastrowid = ( + not self.compiled.inline + and insert_has_sequence + and not self.compiled.returning + and not self._enable_identity_insert + and not self.executemany + ) if self._enable_identity_insert: self.root_connection._cursor_execute( self.cursor, self._opt_encode( - "SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(tbl)), + "SET IDENTITY_INSERT %s ON" + % self.dialect.identifier_preparer.format_table(tbl) + ), (), - self) + self, + ) def post_exec(self): """Disable IDENTITY_INSERT if enabled.""" @@ -1230,29 +1404,35 @@ class MSExecutionContext(default.DefaultExecutionContext): if self.dialect.use_scope_identity: conn._cursor_execute( self.cursor, - "SELECT scope_identity() AS lastrowid", (), self) + "SELECT scope_identity() AS lastrowid", + (), + self, + ) else: - conn._cursor_execute(self.cursor, - "SELECT @@identity AS lastrowid", - (), - self) + conn._cursor_execute( + self.cursor, "SELECT @@identity AS lastrowid", (), self + ) # fetchall() ensures the cursor is consumed without closing it row = self.cursor.fetchall()[0] self._lastrowid = int(row[0]) - if (self.isinsert or self.isupdate or self.isdelete) and \ - self.compiled.returning: + if ( + self.isinsert or self.isupdate or self.isdelete + ) and self.compiled.returning: self._result_proxy = engine.FullyBufferedResultProxy(self) if self._enable_identity_insert: conn._cursor_execute( self.cursor, self._opt_encode( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. format_table( - self.compiled.statement.table)), + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table( + self.compiled.statement.table + ) + ), (), - self) + self, + ) def get_lastrowid(self): return self._lastrowid @@ -1262,9 +1442,12 @@ class MSExecutionContext(default.DefaultExecutionContext): try: self.cursor.execute( self._opt_encode( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. format_table( - self.compiled.statement.table))) + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table( + self.compiled.statement.table + ) + ) + ) except Exception: pass @@ -1281,11 +1464,12 @@ class MSSQLCompiler(compiler.SQLCompiler): extract_map = util.update_copy( compiler.SQLCompiler.extract_map, { - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond', - 'microseconds': 'microsecond' - }) + "doy": "dayofyear", + "dow": "weekday", + "milliseconds": "millisecond", + "microseconds": "microsecond", + }, + ) def __init__(self, *args, **kwargs): self.tablealiases = {} @@ -1298,6 +1482,7 @@ class MSSQLCompiler(compiler.SQLCompiler): else: super_ = getattr(super(MSSQLCompiler, self), fn.__name__) return super_(*arg, **kw) + return decorate def visit_now_func(self, fn, **kw): @@ -1313,20 +1498,22 @@ class MSSQLCompiler(compiler.SQLCompiler): return "LEN%s" % self.function_argspec(fn, **kw) def visit_concat_op_binary(self, binary, operator, **kw): - return "%s + %s" % \ - (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + return "%s + %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def visit_true(self, expr, **kw): - return '1' + return "1" def visit_false(self, expr, **kw): - return '0' + return "0" def visit_match_op_binary(self, binary, operator, **kw): return "CONTAINS (%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def get_select_precolumns(self, select, **kw): """ MS-SQL puts TOP, it's version of LIMIT here """ @@ -1345,7 +1532,8 @@ class MSSQLCompiler(compiler.SQLCompiler): return s else: return compiler.SQLCompiler.get_select_precolumns( - self, select, **kw) + self, select, **kw + ) def get_from_hint_text(self, table, text): return text @@ -1363,20 +1551,21 @@ class MSSQLCompiler(compiler.SQLCompiler): """ if ( - ( - not select._simple_int_limit and - select._limit_clause is not None - ) or ( - select._offset_clause is not None and - not select._simple_int_offset or select._offset + (not select._simple_int_limit and select._limit_clause is not None) + or ( + select._offset_clause is not None + and not select._simple_int_offset + or select._offset ) - ) and not getattr(select, '_mssql_visit', None): + ) and not getattr(select, "_mssql_visit", None): # to use ROW_NUMBER(), an ORDER BY is required. if not select._order_by_clause.clauses: - raise exc.CompileError('MSSQL requires an order_by when ' - 'using an OFFSET or a non-simple ' - 'LIMIT clause') + raise exc.CompileError( + "MSSQL requires an order_by when " + "using an OFFSET or a non-simple " + "LIMIT clause" + ) _order_by_clauses = [ sql_util.unwrap_label_reference(elem) @@ -1385,24 +1574,31 @@ class MSSQLCompiler(compiler.SQLCompiler): limit_clause = select._limit_clause offset_clause = select._offset_clause - kwargs['select_wraps_for'] = select + kwargs["select_wraps_for"] = select select = select._generate() select._mssql_visit = True - select = select.column( - sql.func.ROW_NUMBER().over(order_by=_order_by_clauses) - .label("mssql_rn")).order_by(None).alias() + select = ( + select.column( + sql.func.ROW_NUMBER() + .over(order_by=_order_by_clauses) + .label("mssql_rn") + ) + .order_by(None) + .alias() + ) - mssql_rn = sql.column('mssql_rn') - limitselect = sql.select([c for c in select.c if - c.key != 'mssql_rn']) + mssql_rn = sql.column("mssql_rn") + limitselect = sql.select( + [c for c in select.c if c.key != "mssql_rn"] + ) if offset_clause is not None: limitselect.append_whereclause(mssql_rn > offset_clause) if limit_clause is not None: limitselect.append_whereclause( - mssql_rn <= (limit_clause + offset_clause)) + mssql_rn <= (limit_clause + offset_clause) + ) else: - limitselect.append_whereclause( - mssql_rn <= (limit_clause)) + limitselect.append_whereclause(mssql_rn <= (limit_clause)) return self.process(limitselect, **kwargs) else: return compiler.SQLCompiler.visit_select(self, select, **kwargs) @@ -1422,35 +1618,38 @@ class MSSQLCompiler(compiler.SQLCompiler): @_with_legacy_schema_aliasing def visit_alias(self, alias, **kw): # translate for schema-qualified table aliases - kw['mssql_aliased'] = alias.original + kw["mssql_aliased"] = alias.original return super(MSSQLCompiler, self).visit_alias(alias, **kw) @_with_legacy_schema_aliasing def visit_column(self, column, add_to_result_map=None, **kw): - if column.table is not None and \ - (not self.isupdate and not self.isdelete) or \ - self.is_subquery(): + if ( + column.table is not None + and (not self.isupdate and not self.isdelete) + or self.is_subquery() + ): # translate for schema-qualified table aliases t = self._schema_aliased_table(column.table) if t is not None: converted = expression._corresponding_column_or_error( - t, column) + t, column + ) if add_to_result_map is not None: add_to_result_map( column.name, column.name, (column, column.name, column.key), - column.type + column.type, ) - return super(MSSQLCompiler, self).\ - visit_column(converted, **kw) + return super(MSSQLCompiler, self).visit_column(converted, **kw) return super(MSSQLCompiler, self).visit_column( - column, add_to_result_map=add_to_result_map, **kw) + column, add_to_result_map=add_to_result_map, **kw + ) def _schema_aliased_table(self, table): - if getattr(table, 'schema', None) is not None: + if getattr(table, "schema", None) is not None: if table not in self.tablealiases: self.tablealiases[table] = table.alias() return self.tablealiases[table] @@ -1459,16 +1658,17 @@ class MSSQLCompiler(compiler.SQLCompiler): def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART(%s, %s)' % \ - (field, self.process(extract.expr, **kw)) + return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw)) def visit_savepoint(self, savepoint_stmt): - return "SAVE TRANSACTION %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return "SAVE TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_rollback_to_savepoint(self, savepoint_stmt): - return ("ROLLBACK TRANSACTION %s" - % self.preparer.format_savepoint(savepoint_stmt)) + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_binary(self, binary, **kwargs): """Move bind parameters to the right-hand side of an operator, where @@ -1481,10 +1681,11 @@ class MSSQLCompiler(compiler.SQLCompiler): and not isinstance(binary.right, expression.BindParameter) ): return self.process( - expression.BinaryExpression(binary.right, - binary.left, - binary.operator), - **kwargs) + expression.BinaryExpression( + binary.right, binary.left, binary.operator + ), + **kwargs + ) return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) def returning_clause(self, stmt, returning_cols): @@ -1497,12 +1698,13 @@ class MSSQLCompiler(compiler.SQLCompiler): adapter = sql_util.ClauseAdapter(target) columns = [ - self._label_select_column(None, adapter.traverse(c), - True, False, {}) + self._label_select_column( + None, adapter.traverse(c), True, False, {} + ) for c in expression._select_iterables(returning_cols) ] - return 'OUTPUT ' + ', '.join(columns) + return "OUTPUT " + ", ".join(columns) def get_cte_preamble(self, recursive): # SQL Server finds it too inconvenient to accept @@ -1515,13 +1717,14 @@ class MSSQLCompiler(compiler.SQLCompiler): if isinstance(column, expression.Function): return column.label(None) else: - return super(MSSQLCompiler, self).\ - label_select_column(select, column, asfrom) + return super(MSSQLCompiler, self).label_select_column( + select, column, asfrom + ) def for_update_clause(self, select): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which # SQLAlchemy doesn't use - return '' + return "" def order_by_clause(self, select, **kw): order_by = self.process(select._order_by_clause, **kw) @@ -1532,10 +1735,9 @@ class MSSQLCompiler(compiler.SQLCompiler): else: return "" - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the UPDATE..FROM clause specific to MSSQL. In MSSQL, if the UPDATE statement involves an alias of the table to @@ -1543,13 +1745,12 @@ class MSSQLCompiler(compiler.SQLCompiler): well. Otherwise, it is optional. Here, we add it regardless. """ - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -1558,20 +1759,21 @@ class MSSQLCompiler(compiler.SQLCompiler): self, asfrom=True, iscrud=True, ashint=ashint ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. FROM clause specific to MSSQL. Yes, it has the FROM keyword twice. """ - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) def visit_empty_set_expr(self, type_): - return 'SELECT 1 WHERE 1!=1' + return "SELECT 1 WHERE 1!=1" class MSSQLStrictCompiler(MSSQLCompiler): @@ -1583,20 +1785,21 @@ class MSSQLStrictCompiler(MSSQLCompiler): binds are used. """ + ansi_bind_rules = True def visit_in_op_binary(self, binary, operator, **kw): - kw['literal_binds'] = True + kw["literal_binds"] = True return "%s IN %s" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_notin_op_binary(self, binary, operator, **kw): - kw['literal_binds'] = True + kw["literal_binds"] = True return "%s NOT IN %s" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def render_literal_value(self, value, type_): @@ -1615,23 +1818,28 @@ class MSSQLStrictCompiler(MSSQLCompiler): # SQL Server wants single quotes around the date string. return "'" + str(value) + "'" else: - return super(MSSQLStrictCompiler, self).\ - render_literal_value(value, type_) + return super(MSSQLStrictCompiler, self).render_literal_value( + value, type_ + ) class MSDDLCompiler(compiler.DDLCompiler): - def get_column_specification(self, column, **kwargs): colspec = ( - self.preparer.format_column(column) + " " + self.preparer.format_column(column) + + " " + self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ) ) if column.nullable is not None: - if not column.nullable or column.primary_key or \ - isinstance(column.default, sa_schema.Sequence) or \ - column.autoincrement is True: + if ( + not column.nullable + or column.primary_key + or isinstance(column.default, sa_schema.Sequence) + or column.autoincrement is True + ): colspec += " NOT NULL" else: colspec += " NULL" @@ -1639,15 +1847,18 @@ class MSDDLCompiler(compiler.DDLCompiler): if column.table is None: raise exc.CompileError( "mssql requires Table-bound columns " - "in order to generate DDL") + "in order to generate DDL" + ) # install an IDENTITY Sequence if we either a sequence or an implicit # IDENTITY column if isinstance(column.default, sa_schema.Sequence): - if (column.default.start is not None or - column.default.increment is not None or - column is not column.table._autoincrement_column): + if ( + column.default.start is not None + or column.default.increment is not None + or column is not column.table._autoincrement_column + ): util.warn_deprecated( "Use of Sequence with SQL Server in order to affect the " "parameters of the IDENTITY value is deprecated, as " @@ -1655,18 +1866,23 @@ class MSDDLCompiler(compiler.DDLCompiler): "will correspond to an actual SQL Server " "CREATE SEQUENCE in " "a future release. Please use the mssql_identity_start " - "and mssql_identity_increment parameters.") + "and mssql_identity_increment parameters." + ) if column.default.start == 0: start = 0 else: start = column.default.start or 1 - colspec += " IDENTITY(%s,%s)" % (start, - column.default.increment or 1) - elif column is column.table._autoincrement_column or \ - column.autoincrement is True: - start = column.dialect_options['mssql']['identity_start'] - increment = column.dialect_options['mssql']['identity_increment'] + colspec += " IDENTITY(%s,%s)" % ( + start, + column.default.increment or 1, + ) + elif ( + column is column.table._autoincrement_column + or column.autoincrement is True + ): + start = column.dialect_options["mssql"]["identity_start"] + increment = column.dialect_options["mssql"]["identity_increment"] colspec += " IDENTITY(%s,%s)" % (start, increment) else: default = self.get_column_default_string(column) @@ -1684,84 +1900,88 @@ class MSDDLCompiler(compiler.DDLCompiler): text += "UNIQUE " # handle clustering option - clustered = index.dialect_options['mssql']['clustered'] + clustered = index.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=include_schema), - preparer.format_table(index.table), - ', '.join( - self.sql_compiler.process(expr, - include_table=False, - literal_binds=True) for - expr in index.expressions) - ) + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table(index.table), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) # handle other included columns - if index.dialect_options['mssql']['include']: - inclusions = [index.table.c[col] - if isinstance(col, util.string_types) else col - for col in - index.dialect_options['mssql']['include'] - ] + if index.dialect_options["mssql"]["include"]: + inclusions = [ + index.table.c[col] + if isinstance(col, util.string_types) + else col + for col in index.dialect_options["mssql"]["include"] + ] - text += " INCLUDE (%s)" \ - % ', '.join([preparer.quote(c.name) - for c in inclusions]) + text += " INCLUDE (%s)" % ", ".join( + [preparer.quote(c.name) for c in inclusions] + ) return text def visit_drop_index(self, drop): return "\nDROP INDEX %s ON %s" % ( self._prepared_index_name(drop.element, include_schema=False), - self.preparer.format_table(drop.element.table) + self.preparer.format_table(drop.element.table), ) def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) text += "PRIMARY KEY " - clustered = constraint.dialect_options['mssql']['clustered'] + clustered = constraint.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in constraint) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) text += self.define_constraint_deferrability(constraint) return text def visit_unique_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) text += "UNIQUE " - clustered = constraint.dialect_options['mssql']['clustered'] + clustered = constraint.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in constraint) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) text += self.define_constraint_deferrability(constraint) return text @@ -1771,8 +1991,11 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(MSIdentifierPreparer, self).__init__( - dialect, initial_quote='[', - final_quote=']', quote_case_sensitive_collations=False) + dialect, + initial_quote="[", + final_quote="]", + quote_case_sensitive_collations=False, + ) def _escape_identifier(self, value): return value @@ -1783,7 +2006,9 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): dbname, owner = _schema_elements(schema) if dbname: result = "%s.%s" % ( - self.quote(dbname, force), self.quote(owner, force)) + self.quote(dbname, force), + self.quote(owner, force), + ) elif owner: result = self.quote(owner, force) else: @@ -1794,16 +2019,37 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): def _db_plus_owner_listing(fn): def wrap(dialect, connection, schema=None, **kw): dbname, owner = _owner_plus_db(dialect, schema) - return _switch_db(dbname, connection, fn, dialect, connection, - dbname, owner, schema, **kw) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + dbname, + owner, + schema, + **kw + ) + return update_wrapper(wrap, fn) def _db_plus_owner(fn): def wrap(dialect, connection, tablename, schema=None, **kw): dbname, owner = _owner_plus_db(dialect, schema) - return _switch_db(dbname, connection, fn, dialect, connection, - tablename, dbname, owner, schema, **kw) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + tablename, + dbname, + owner, + schema, + **kw + ) + return update_wrapper(wrap, fn) @@ -1837,9 +2083,9 @@ def _schema_elements(schema): for token in re.split(r"(\[|\]|\.)", schema): if not token: continue - if token == '[': + if token == "[": bracket = True - elif token == ']': + elif token == "]": bracket = False elif not bracket and token == ".": push.append(symbol) @@ -1857,7 +2103,7 @@ def _schema_elements(schema): class MSDialect(default.DefaultDialect): - name = 'mssql' + name = "mssql" supports_default_values = True supports_empty_insert = False execution_ctx_cls = MSExecutionContext @@ -1871,9 +2117,9 @@ class MSDialect(default.DefaultDialect): sqltypes.Time: TIME, } - engine_config_types = default.DefaultDialect.engine_config_types.union([ - ('legacy_schema_aliasing', util.asbool), - ]) + engine_config_types = default.DefaultDialect.engine_config_types.union( + [("legacy_schema_aliasing", util.asbool)] + ) ischema_names = ischema_names @@ -1890,36 +2136,30 @@ class MSDialect(default.DefaultDialect): preparer = MSIdentifierPreparer construct_arguments = [ - (sa_schema.PrimaryKeyConstraint, { - "clustered": None - }), - (sa_schema.UniqueConstraint, { - "clustered": None - }), - (sa_schema.Index, { - "clustered": None, - "include": None - }), - (sa_schema.Column, { - "identity_start": 1, - "identity_increment": 1 - }) + (sa_schema.PrimaryKeyConstraint, {"clustered": None}), + (sa_schema.UniqueConstraint, {"clustered": None}), + (sa_schema.Index, {"clustered": None, "include": None}), + (sa_schema.Column, {"identity_start": 1, "identity_increment": 1}), ] - def __init__(self, - query_timeout=None, - use_scope_identity=True, - max_identifier_length=None, - schema_name="dbo", - isolation_level=None, - deprecate_large_types=None, - legacy_schema_aliasing=False, **opts): + def __init__( + self, + query_timeout=None, + use_scope_identity=True, + max_identifier_length=None, + schema_name="dbo", + isolation_level=None, + deprecate_large_types=None, + legacy_schema_aliasing=False, + **opts + ): self.query_timeout = int(query_timeout or 0) self.schema_name = schema_name self.use_scope_identity = use_scope_identity - self.max_identifier_length = int(max_identifier_length or 0) or \ - self.max_identifier_length + self.max_identifier_length = ( + int(max_identifier_length or 0) or self.max_identifier_length + ) self.deprecate_large_types = deprecate_large_types self.legacy_schema_aliasing = legacy_schema_aliasing @@ -1936,27 +2176,33 @@ class MSDialect(default.DefaultDialect): # SQL Server does not support RELEASE SAVEPOINT pass - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ', - 'SNAPSHOT']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "SNAPSHOT", + ] + ) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") if level not in self._isolation_lookup: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() - cursor.execute( - "SET TRANSACTION ISOLATION LEVEL %s" % level) + cursor.execute("SET TRANSACTION ISOLATION LEVEL %s" % level) cursor.close() def get_isolation_level(self, connection): if self.server_version_info < MS_2005_VERSION: raise NotImplementedError( - "Can't fetch isolation level prior to SQL Server 2005") + "Can't fetch isolation level prior to SQL Server 2005" + ) last_error = None @@ -1964,7 +2210,8 @@ class MSDialect(default.DefaultDialect): for view in views: cursor = connection.cursor() try: - cursor.execute(""" + cursor.execute( + """ SELECT CASE transaction_isolation_level WHEN 0 THEN NULL WHEN 1 THEN 'READ UNCOMMITTED' @@ -1974,7 +2221,9 @@ class MSDialect(default.DefaultDialect): WHEN 5 THEN 'SNAPSHOT' END AS TRANSACTION_ISOLATION_LEVEL FROM %s where session_id = @@SPID - """ % view) + """ + % view + ) val = cursor.fetchone()[0] except self.dbapi.Error as err: # Python3 scoping rules @@ -1987,7 +2236,8 @@ class MSDialect(default.DefaultDialect): else: util.warn( "Could not fetch transaction isolation level, " - "tried views: %s; final error was: %s" % (views, last_error)) + "tried views: %s; final error was: %s" % (views, last_error) + ) raise NotImplementedError( "Can't fetch isolation level on this particular " @@ -2000,8 +2250,10 @@ class MSDialect(default.DefaultDialect): def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None @@ -2010,16 +2262,20 @@ class MSDialect(default.DefaultDialect): if self.server_version_info[0] not in list(range(8, 17)): util.warn( "Unrecognized server version info '%s'. Some SQL Server " - "features may not function properly." % - ".".join(str(x) for x in self.server_version_info)) - if self.server_version_info >= MS_2005_VERSION and \ - 'implicit_returning' not in self.__dict__: + "features may not function properly." + % ".".join(str(x) for x in self.server_version_info) + ) + if ( + self.server_version_info >= MS_2005_VERSION + and "implicit_returning" not in self.__dict__ + ): self.implicit_returning = True if self.server_version_info >= MS_2008_VERSION: self.supports_multivalues_insert = True if self.deprecate_large_types is None: - self.deprecate_large_types = \ + self.deprecate_large_types = ( self.server_version_info >= MS_2012_VERSION + ) def _get_default_schema_name(self, connection): if self.server_version_info < MS_2005_VERSION: @@ -2039,17 +2295,19 @@ class MSDialect(default.DefaultDialect): whereclause = columns.c.table_name == tablename if owner: - whereclause = sql.and_(whereclause, - columns.c.table_schema == owner) + whereclause = sql.and_( + whereclause, columns.c.table_schema == owner + ) s = sql.select([columns], whereclause) c = connection.execute(s) return c.first() is not None @reflection.cache def get_schema_names(self, connection, **kw): - s = sql.select([ischema.schemata.c.schema_name], - order_by=[ischema.schemata.c.schema_name] - ) + s = sql.select( + [ischema.schemata.c.schema_name], + order_by=[ischema.schemata.c.schema_name], + ) schema_names = [r[0] for r in connection.execute(s)] return schema_names @@ -2057,12 +2315,13 @@ class MSDialect(default.DefaultDialect): @_db_plus_owner_listing def get_table_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables - s = sql.select([tables.c.table_name], - sql.and_( - tables.c.table_schema == owner, - tables.c.table_type == 'BASE TABLE' - ), - order_by=[tables.c.table_name] + s = sql.select( + [tables.c.table_name], + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == "BASE TABLE", + ), + order_by=[tables.c.table_name], ) table_names = [r[0] for r in connection.execute(s)] return table_names @@ -2071,12 +2330,12 @@ class MSDialect(default.DefaultDialect): @_db_plus_owner_listing def get_view_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables - s = sql.select([tables.c.table_name], - sql.and_( - tables.c.table_schema == owner, - tables.c.table_type == 'VIEW' - ), - order_by=[tables.c.table_name] + s = sql.select( + [tables.c.table_name], + sql.and_( + tables.c.table_schema == owner, tables.c.table_type == "VIEW" + ), + order_by=[tables.c.table_name], ) view_names = [r[0] for r in connection.execute(s)] return view_names @@ -2090,30 +2349,33 @@ class MSDialect(default.DefaultDialect): return [] rp = connection.execute( - sql.text("select ind.index_id, ind.is_unique, ind.name " - "from sys.indexes as ind join sys.tables as tab on " - "ind.object_id=tab.object_id " - "join sys.schemas as sch on sch.schema_id=tab.schema_id " - "where tab.name = :tabname " - "and sch.name=:schname " - "and ind.is_primary_key=0 and ind.type != 0", - bindparams=[ - sql.bindparam('tabname', tablename, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) - ], - typemap={ - 'name': sqltypes.Unicode() - } - ) + sql.text( + "select ind.index_id, ind.is_unique, ind.name " + "from sys.indexes as ind join sys.tables as tab on " + "ind.object_id=tab.object_id " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + "where tab.name = :tabname " + "and sch.name=:schname " + "and ind.is_primary_key=0 and ind.type != 0", + bindparams=[ + sql.bindparam( + "tabname", + tablename, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), + ], + typemap={"name": sqltypes.Unicode()}, + ) ) indexes = {} for row in rp: - indexes[row['index_id']] = { - 'name': row['name'], - 'unique': row['is_unique'] == 1, - 'column_names': [] + indexes[row["index_id"]] = { + "name": row["name"], + "unique": row["is_unique"] == 1, + "column_names": [], } rp = connection.execute( sql.text( @@ -2127,24 +2389,29 @@ class MSDialect(default.DefaultDialect): "where tab.name=:tabname " "and sch.name=:schname", bindparams=[ - sql.bindparam('tabname', tablename, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) + sql.bindparam( + "tabname", + tablename, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), ], - typemap={'name': sqltypes.Unicode()} - ), + typemap={"name": sqltypes.Unicode()}, + ) ) for row in rp: - if row['index_id'] in indexes: - indexes[row['index_id']]['column_names'].append(row['name']) + if row["index_id"] in indexes: + indexes[row["index_id"]]["column_names"].append(row["name"]) return list(indexes.values()) @reflection.cache @_db_plus_owner - def get_view_definition(self, connection, viewname, - dbname, owner, schema, **kw): + def get_view_definition( + self, connection, viewname, dbname, owner, schema, **kw + ): rp = connection.execute( sql.text( "select definition from sys.sql_modules as mod, " @@ -2155,11 +2422,15 @@ class MSDialect(default.DefaultDialect): "views.schema_id=sch.schema_id and " "views.name=:viewname and sch.name=:schname", bindparams=[ - sql.bindparam('viewname', viewname, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) - ] + sql.bindparam( + "viewname", + viewname, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), + ], ) ) @@ -2173,12 +2444,15 @@ class MSDialect(default.DefaultDialect): # Get base columns columns = ischema.columns if owner: - whereclause = sql.and_(columns.c.table_name == tablename, - columns.c.table_schema == owner) + whereclause = sql.and_( + columns.c.table_name == tablename, + columns.c.table_schema == owner, + ) else: whereclause = columns.c.table_name == tablename - s = sql.select([columns], whereclause, - order_by=[columns.c.ordinal_position]) + s = sql.select( + [columns], whereclause, order_by=[columns.c.ordinal_position] + ) c = connection.execute(s) cols = [] @@ -2186,57 +2460,76 @@ class MSDialect(default.DefaultDialect): row = c.fetchone() if row is None: break - (name, type, nullable, charlen, - numericprec, numericscale, default, collation) = ( + ( + name, + type, + nullable, + charlen, + numericprec, + numericscale, + default, + collation, + ) = ( row[columns.c.column_name], row[columns.c.data_type], - row[columns.c.is_nullable] == 'YES', + row[columns.c.is_nullable] == "YES", row[columns.c.character_maximum_length], row[columns.c.numeric_precision], row[columns.c.numeric_scale], row[columns.c.column_default], - row[columns.c.collation_name] + row[columns.c.collation_name], ) coltype = self.ischema_names.get(type, None) kwargs = {} - if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, - MSNText, MSBinary, MSVarBinary, - sqltypes.LargeBinary): + if coltype in ( + MSString, + MSChar, + MSNVarchar, + MSNChar, + MSText, + MSNText, + MSBinary, + MSVarBinary, + sqltypes.LargeBinary, + ): if charlen == -1: charlen = None - kwargs['length'] = charlen + kwargs["length"] = charlen if collation: - kwargs['collation'] = collation + kwargs["collation"] = collation if coltype is None: util.warn( - "Did not recognize type '%s' of column '%s'" % - (type, name)) + "Did not recognize type '%s' of column '%s'" % (type, name) + ) coltype = sqltypes.NULLTYPE else: - if issubclass(coltype, sqltypes.Numeric) and \ - coltype is not MSReal: - kwargs['scale'] = numericscale - kwargs['precision'] = numericprec + if ( + issubclass(coltype, sqltypes.Numeric) + and coltype is not MSReal + ): + kwargs["scale"] = numericscale + kwargs["precision"] = numericprec coltype = coltype(**kwargs) cdict = { - 'name': name, - 'type': coltype, - 'nullable': nullable, - 'default': default, - 'autoincrement': False, + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": False, } cols.append(cdict) # autoincrement and identity colmap = {} for col in cols: - colmap[col['name']] = col + colmap[col["name"]] = col # We also run an sp_columns to check for identity columns: - cursor = connection.execute("sp_columns @table_name = '%s', " - "@table_owner = '%s'" - % (tablename, owner)) + cursor = connection.execute( + "sp_columns @table_name = '%s', " + "@table_owner = '%s'" % (tablename, owner) + ) ic = None while True: row = cursor.fetchone() @@ -2245,10 +2538,10 @@ class MSDialect(default.DefaultDialect): (col_name, type_name) = row[3], row[5] if type_name.endswith("identity") and col_name in colmap: ic = col_name - colmap[col_name]['autoincrement'] = True - colmap[col_name]['dialect_options'] = { - 'mssql_identity_start': 1, - 'mssql_identity_increment': 1 + colmap[col_name]["autoincrement"] = True + colmap[col_name]["dialect_options"] = { + "mssql_identity_start": 1, + "mssql_identity_increment": 1, } break cursor.close() @@ -2262,64 +2555,74 @@ class MSDialect(default.DefaultDialect): row = cursor.first() if row is not None and row[0] is not None: - colmap[ic]['dialect_options'].update({ - 'mssql_identity_start': int(row[0]), - 'mssql_identity_increment': int(row[1]) - }) + colmap[ic]["dialect_options"].update( + { + "mssql_identity_start": int(row[0]), + "mssql_identity_increment": int(row[1]), + } + ) return cols @reflection.cache @_db_plus_owner - def get_pk_constraint(self, connection, tablename, - dbname, owner, schema, **kw): + def get_pk_constraint( + self, connection, tablename, dbname, owner, schema, **kw + ): pkeys = [] TC = ischema.constraints - C = ischema.key_constraints.alias('C') + C = ischema.key_constraints.alias("C") # Primary key constraints - s = sql.select([C.c.column_name, - TC.c.constraint_type, - C.c.constraint_name], - sql.and_(TC.c.constraint_name == C.c.constraint_name, - TC.c.table_schema == C.c.table_schema, - C.c.table_name == tablename, - C.c.table_schema == owner) - ) + s = sql.select( + [C.c.column_name, TC.c.constraint_type, C.c.constraint_name], + sql.and_( + TC.c.constraint_name == C.c.constraint_name, + TC.c.table_schema == C.c.table_schema, + C.c.table_name == tablename, + C.c.table_schema == owner, + ), + ) c = connection.execute(s) constraint_name = None for row in c: - if 'PRIMARY' in row[TC.c.constraint_type.name]: + if "PRIMARY" in row[TC.c.constraint_type.name]: pkeys.append(row[0]) if constraint_name is None: constraint_name = row[C.c.constraint_name.name] - return {'constrained_columns': pkeys, 'name': constraint_name} + return {"constrained_columns": pkeys, "name": constraint_name} @reflection.cache @_db_plus_owner - def get_foreign_keys(self, connection, tablename, - dbname, owner, schema, **kw): + def get_foreign_keys( + self, connection, tablename, dbname, owner, schema, **kw + ): RR = ischema.ref_constraints - C = ischema.key_constraints.alias('C') - R = ischema.key_constraints.alias('R') + C = ischema.key_constraints.alias("C") + R = ischema.key_constraints.alias("R") # Foreign key constraints - s = sql.select([C.c.column_name, - R.c.table_schema, R.c.table_name, R.c.column_name, - RR.c.constraint_name, RR.c.match_option, - RR.c.update_rule, - RR.c.delete_rule], - sql.and_(C.c.table_name == tablename, - C.c.table_schema == owner, - RR.c.constraint_schema == C.c.table_schema, - C.c.constraint_name == RR.c.constraint_name, - R.c.constraint_name == - RR.c.unique_constraint_name, - R.c.constraint_schema == - RR.c.unique_constraint_schema, - C.c.ordinal_position == R.c.ordinal_position - ), - order_by=[RR.c.constraint_name, R.c.ordinal_position] - ) + s = sql.select( + [ + C.c.column_name, + R.c.table_schema, + R.c.table_name, + R.c.column_name, + RR.c.constraint_name, + RR.c.match_option, + RR.c.update_rule, + RR.c.delete_rule, + ], + sql.and_( + C.c.table_name == tablename, + C.c.table_schema == owner, + RR.c.constraint_schema == C.c.table_schema, + C.c.constraint_name == RR.c.constraint_name, + R.c.constraint_name == RR.c.unique_constraint_name, + R.c.constraint_schema == RR.c.unique_constraint_schema, + C.c.ordinal_position == R.c.ordinal_position, + ), + order_by=[RR.c.constraint_name, R.c.ordinal_position], + ) # group rows by constraint ID, to handle multi-column FKs fkeys = [] @@ -2327,11 +2630,11 @@ class MSDialect(default.DefaultDialect): def fkey_rec(): return { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': None, - 'referred_columns': [] + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], } fkeys = util.defaultdict(fkey_rec) @@ -2340,17 +2643,18 @@ class MSDialect(default.DefaultDialect): scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r rec = fkeys[rfknm] - rec['name'] = rfknm - if not rec['referred_table']: - rec['referred_table'] = rtbl + rec["name"] = rfknm + if not rec["referred_table"]: + rec["referred_table"] = rtbl if schema is not None or owner != rschema: if dbname: rschema = dbname + "." + rschema - rec['referred_schema'] = rschema + rec["referred_schema"] = rschema - local_cols, remote_cols = \ - rec['constrained_columns'],\ - rec['referred_columns'] + local_cols, remote_cols = ( + rec["constrained_columns"], + rec["referred_columns"], + ) local_cols.append(scol) remote_cols.append(rcol) |
