diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
| commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
| tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/dialects/mssql/base.py | |
| parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
| download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz | |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/dialects/mssql/base.py')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 1238 |
1 file changed, 771 insertions, 467 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 9269225d3..161297015 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -655,9 +655,22 @@ from ...sql import compiler, expression, util as sql_util, quoted_name from ... import engine from ...engine import reflection, default from ... import types as sqltypes -from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ - FLOAT, DATETIME, DATE, BINARY, \ - TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR +from ...types import ( + INTEGER, + BIGINT, + SMALLINT, + DECIMAL, + NUMERIC, + FLOAT, + DATETIME, + DATE, + BINARY, + TEXT, + VARCHAR, + NVARCHAR, + CHAR, + NCHAR, +) from ...util import update_wrapper @@ -672,48 +685,202 @@ MS_2005_VERSION = (9,) MS_2000_VERSION = (8,) RESERVED_WORDS = set( - ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', - 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', - 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', - 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', - 'containstable', 'continue', 'convert', 'create', 'cross', 'current', - 'current_date', 'current_time', 'current_timestamp', 'current_user', - 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', - 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', - 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', - 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', - 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', - 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', - 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', - 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', - 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', - 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', - 
'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', - 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', - 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', - 'reconfigure', 'references', 'replication', 'restore', 'restrict', - 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', - 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', - 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', - 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', - 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', - 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', - 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', - 'writetext', - ]) + [ + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "authorization", + "backup", + "begin", + "between", + "break", + "browse", + "bulk", + "by", + "cascade", + "case", + "check", + "checkpoint", + "close", + "clustered", + "coalesce", + "collate", + "column", + "commit", + "compute", + "constraint", + "contains", + "containstable", + "continue", + "convert", + "create", + "cross", + "current", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "dbcc", + "deallocate", + "declare", + "default", + "delete", + "deny", + "desc", + "disk", + "distinct", + "distributed", + "double", + "drop", + "dump", + "else", + "end", + "errlvl", + "escape", + "except", + "exec", + "execute", + "exists", + "exit", + "external", + "fetch", + "file", + "fillfactor", + "for", + "foreign", + "freetext", + "freetexttable", + "from", + "full", + "function", + "goto", + "grant", + "group", + "having", + "holdlock", + "identity", + "identity_insert", + "identitycol", + "if", + "in", + "index", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "key", + "kill", + "left", + "like", + "lineno", + "load", + "merge", + "national", + "nocheck", 
+ "nonclustered", + "not", + "null", + "nullif", + "of", + "off", + "offsets", + "on", + "open", + "opendatasource", + "openquery", + "openrowset", + "openxml", + "option", + "or", + "order", + "outer", + "over", + "percent", + "pivot", + "plan", + "precision", + "primary", + "print", + "proc", + "procedure", + "public", + "raiserror", + "read", + "readtext", + "reconfigure", + "references", + "replication", + "restore", + "restrict", + "return", + "revert", + "revoke", + "right", + "rollback", + "rowcount", + "rowguidcol", + "rule", + "save", + "schema", + "securityaudit", + "select", + "session_user", + "set", + "setuser", + "shutdown", + "some", + "statistics", + "system_user", + "table", + "tablesample", + "textsize", + "then", + "to", + "top", + "tran", + "transaction", + "trigger", + "truncate", + "tsequal", + "union", + "unique", + "unpivot", + "update", + "updatetext", + "use", + "user", + "values", + "varying", + "view", + "waitfor", + "when", + "where", + "while", + "with", + "writetext", + ] +) class REAL(sqltypes.REAL): - __visit_name__ = 'REAL' + __visit_name__ = "REAL" def __init__(self, **kw): # REAL is a synonym for FLOAT(24) on SQL server - kw['precision'] = 24 + kw["precision"] = 24 super(REAL, self).__init__(**kw) class TINYINT(sqltypes.Integer): - __visit_name__ = 'TINYINT' + __visit_name__ = "TINYINT" # MSSQL DATE/TIME types have varied behavior, sometimes returning @@ -721,14 +888,15 @@ class TINYINT(sqltypes.Integer): # filter bind parameters into datetime objects (required by pyodbc, # not sure about other dialects). 
-class _MSDate(sqltypes.Date): +class _MSDate(sqltypes.Date): def bind_processor(self, dialect): def process(value): if type(value) == datetime.date: return datetime.datetime(value.year, value.month, value.day) else: return value + return process _reg = re.compile(r"(\d+)-(\d+)-(\d+)") @@ -741,18 +909,16 @@ class _MSDate(sqltypes.Date): m = self._reg.match(value) if not m: raise ValueError( - "could not parse %r as a date value" % (value, )) - return datetime.date(*[ - int(x or 0) - for x in m.groups() - ]) + "could not parse %r as a date value" % (value,) + ) + return datetime.date(*[int(x or 0) for x in m.groups()]) else: return value + return process class TIME(sqltypes.TIME): - def __init__(self, precision=None, **kwargs): self.precision = precision super(TIME, self).__init__() @@ -763,10 +929,12 @@ class TIME(sqltypes.TIME): def process(value): if isinstance(value, datetime.datetime): value = datetime.datetime.combine( - self.__zero_date, value.time()) + self.__zero_date, value.time() + ) elif isinstance(value, datetime.time): value = datetime.datetime.combine(self.__zero_date, value) return value + return process _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?") @@ -779,24 +947,26 @@ class TIME(sqltypes.TIME): m = self._reg.match(value) if not m: raise ValueError( - "could not parse %r as a time value" % (value, )) - return datetime.time(*[ - int(x or 0) - for x in m.groups()]) + "could not parse %r as a time value" % (value,) + ) + return datetime.time(*[int(x or 0) for x in m.groups()]) else: return value + return process + + _MSTime = TIME class _DateTimeBase(object): - def bind_processor(self, dialect): def process(value): if type(value) == datetime.date: return datetime.datetime(value.year, value.month, value.day) else: return value + return process @@ -805,11 +975,11 @@ class _MSDateTime(_DateTimeBase, sqltypes.DateTime): class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): - __visit_name__ = 'SMALLDATETIME' + __visit_name__ = "SMALLDATETIME" 
class DATETIME2(_DateTimeBase, sqltypes.DateTime): - __visit_name__ = 'DATETIME2' + __visit_name__ = "DATETIME2" def __init__(self, precision=None, **kw): super(DATETIME2, self).__init__(**kw) @@ -818,7 +988,7 @@ class DATETIME2(_DateTimeBase, sqltypes.DateTime): # TODO: is this not an Interval ? class DATETIMEOFFSET(sqltypes.TypeEngine): - __visit_name__ = 'DATETIMEOFFSET' + __visit_name__ = "DATETIMEOFFSET" def __init__(self, precision=None, **kwargs): self.precision = precision @@ -847,7 +1017,7 @@ class TIMESTAMP(sqltypes._Binary): """ - __visit_name__ = 'TIMESTAMP' + __visit_name__ = "TIMESTAMP" # expected by _Binary to be present length = None @@ -866,12 +1036,14 @@ class TIMESTAMP(sqltypes._Binary): def result_processor(self, dialect, coltype): super_ = super(TIMESTAMP, self).result_processor(dialect, coltype) if self.convert_int: + def process(value): value = super_(value) if value is not None: # https://stackoverflow.com/a/30403242/34549 - value = int(codecs.encode(value, 'hex'), 16) + value = int(codecs.encode(value, "hex"), 16) return value + return process else: return super_ @@ -898,7 +1070,7 @@ class ROWVERSION(TIMESTAMP): """ - __visit_name__ = 'ROWVERSION' + __visit_name__ = "ROWVERSION" class NTEXT(sqltypes.UnicodeText): @@ -906,7 +1078,7 @@ class NTEXT(sqltypes.UnicodeText): """MSSQL NTEXT type, for variable-length unicode text up to 2^30 characters.""" - __visit_name__ = 'NTEXT' + __visit_name__ = "NTEXT" class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): @@ -925,11 +1097,12 @@ class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): """ - __visit_name__ = 'VARBINARY' + + __visit_name__ = "VARBINARY" class IMAGE(sqltypes.LargeBinary): - __visit_name__ = 'IMAGE' + __visit_name__ = "IMAGE" class XML(sqltypes.Text): @@ -943,19 +1116,20 @@ class XML(sqltypes.Text): .. 
versionadded:: 1.1.11 """ - __visit_name__ = 'XML' + + __visit_name__ = "XML" class BIT(sqltypes.TypeEngine): - __visit_name__ = 'BIT' + __visit_name__ = "BIT" class MONEY(sqltypes.TypeEngine): - __visit_name__ = 'MONEY' + __visit_name__ = "MONEY" class SMALLMONEY(sqltypes.TypeEngine): - __visit_name__ = 'SMALLMONEY' + __visit_name__ = "SMALLMONEY" class UNIQUEIDENTIFIER(sqltypes.TypeEngine): @@ -963,7 +1137,8 @@ class UNIQUEIDENTIFIER(sqltypes.TypeEngine): class SQL_VARIANT(sqltypes.TypeEngine): - __visit_name__ = 'SQL_VARIANT' + __visit_name__ = "SQL_VARIANT" + # old names. MSDateTime = _MSDateTime @@ -990,36 +1165,36 @@ MSUniqueIdentifier = UNIQUEIDENTIFIER MSVariant = SQL_VARIANT ischema_names = { - 'int': INTEGER, - 'bigint': BIGINT, - 'smallint': SMALLINT, - 'tinyint': TINYINT, - 'varchar': VARCHAR, - 'nvarchar': NVARCHAR, - 'char': CHAR, - 'nchar': NCHAR, - 'text': TEXT, - 'ntext': NTEXT, - 'decimal': DECIMAL, - 'numeric': NUMERIC, - 'float': FLOAT, - 'datetime': DATETIME, - 'datetime2': DATETIME2, - 'datetimeoffset': DATETIMEOFFSET, - 'date': DATE, - 'time': TIME, - 'smalldatetime': SMALLDATETIME, - 'binary': BINARY, - 'varbinary': VARBINARY, - 'bit': BIT, - 'real': REAL, - 'image': IMAGE, - 'xml': XML, - 'timestamp': TIMESTAMP, - 'money': MONEY, - 'smallmoney': SMALLMONEY, - 'uniqueidentifier': UNIQUEIDENTIFIER, - 'sql_variant': SQL_VARIANT, + "int": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "tinyint": TINYINT, + "varchar": VARCHAR, + "nvarchar": NVARCHAR, + "char": CHAR, + "nchar": NCHAR, + "text": TEXT, + "ntext": NTEXT, + "decimal": DECIMAL, + "numeric": NUMERIC, + "float": FLOAT, + "datetime": DATETIME, + "datetime2": DATETIME2, + "datetimeoffset": DATETIMEOFFSET, + "date": DATE, + "time": TIME, + "smalldatetime": SMALLDATETIME, + "binary": BINARY, + "varbinary": VARBINARY, + "bit": BIT, + "real": REAL, + "image": IMAGE, + "xml": XML, + "timestamp": TIMESTAMP, + "money": MONEY, + "smallmoney": SMALLMONEY, + "uniqueidentifier": 
UNIQUEIDENTIFIER, + "sql_variant": SQL_VARIANT, } @@ -1030,8 +1205,8 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): """ - if getattr(type_, 'collation', None): - collation = 'COLLATE %s' % type_.collation + if getattr(type_, "collation", None): + collation = "COLLATE %s" % type_.collation else: collation = None @@ -1041,15 +1216,14 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): if length: spec = spec + "(%s)" % length - return ' '.join([c for c in (spec, collation) - if c is not None]) + return " ".join([c for c in (spec, collation) if c is not None]) def visit_FLOAT(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is None: return "FLOAT" else: - return "FLOAT(%(precision)s)" % {'precision': precision} + return "FLOAT(%(precision)s)" % {"precision": precision} def visit_TINYINT(self, type_, **kw): return "TINYINT" @@ -1061,7 +1235,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "DATETIMEOFFSET" def visit_TIME(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is not None: return "TIME(%s)" % precision else: @@ -1074,7 +1248,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "ROWVERSION" def visit_DATETIME2(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is not None: return "DATETIME2(%s)" % precision else: @@ -1105,7 +1279,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return self._extend("TEXT", type_) def visit_VARCHAR(self, type_, **kw): - return self._extend("VARCHAR", type_, length=type_.length or 'max') + return self._extend("VARCHAR", type_, length=type_.length or "max") def visit_CHAR(self, type_, **kw): return self._extend("CHAR", type_) @@ -1114,7 +1288,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return self._extend("NCHAR", type_) def 
visit_NVARCHAR(self, type_, **kw): - return self._extend("NVARCHAR", type_, length=type_.length or 'max') + return self._extend("NVARCHAR", type_, length=type_.length or "max") def visit_date(self, type_, **kw): if self.dialect.server_version_info < MS_2008_VERSION: @@ -1141,10 +1315,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "XML" def visit_VARBINARY(self, type_, **kw): - return self._extend( - "VARBINARY", - type_, - length=type_.length or 'max') + return self._extend("VARBINARY", type_, length=type_.length or "max") def visit_boolean(self, type_, **kw): return self.visit_BIT(type_) @@ -1156,13 +1327,13 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "MONEY" def visit_SMALLMONEY(self, type_, **kw): - return 'SMALLMONEY' + return "SMALLMONEY" def visit_UNIQUEIDENTIFIER(self, type_, **kw): return "UNIQUEIDENTIFIER" def visit_SQL_VARIANT(self, type_, **kw): - return 'SQL_VARIANT' + return "SQL_VARIANT" class MSExecutionContext(default.DefaultExecutionContext): @@ -1186,41 +1357,44 @@ class MSExecutionContext(default.DefaultExecutionContext): insert_has_sequence = seq_column is not None if insert_has_sequence: - self._enable_identity_insert = \ - seq_column.key in self.compiled_parameters[0] or \ - ( - self.compiled.statement.parameters and ( - ( - self.compiled.statement._has_multi_parameters - and - seq_column.key in - self.compiled.statement.parameters[0] - ) or ( - not - self.compiled.statement._has_multi_parameters - and - seq_column.key in - self.compiled.statement.parameters - ) + self._enable_identity_insert = seq_column.key in self.compiled_parameters[ + 0 + ] or ( + self.compiled.statement.parameters + and ( + ( + self.compiled.statement._has_multi_parameters + and seq_column.key + in self.compiled.statement.parameters[0] + ) + or ( + not self.compiled.statement._has_multi_parameters + and seq_column.key + in self.compiled.statement.parameters ) ) + ) else: self._enable_identity_insert = False - self._select_lastrowid = 
not self.compiled.inline and \ - insert_has_sequence and \ - not self.compiled.returning and \ - not self._enable_identity_insert and \ - not self.executemany + self._select_lastrowid = ( + not self.compiled.inline + and insert_has_sequence + and not self.compiled.returning + and not self._enable_identity_insert + and not self.executemany + ) if self._enable_identity_insert: self.root_connection._cursor_execute( self.cursor, self._opt_encode( - "SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(tbl)), + "SET IDENTITY_INSERT %s ON" + % self.dialect.identifier_preparer.format_table(tbl) + ), (), - self) + self, + ) def post_exec(self): """Disable IDENTITY_INSERT if enabled.""" @@ -1230,29 +1404,35 @@ class MSExecutionContext(default.DefaultExecutionContext): if self.dialect.use_scope_identity: conn._cursor_execute( self.cursor, - "SELECT scope_identity() AS lastrowid", (), self) + "SELECT scope_identity() AS lastrowid", + (), + self, + ) else: - conn._cursor_execute(self.cursor, - "SELECT @@identity AS lastrowid", - (), - self) + conn._cursor_execute( + self.cursor, "SELECT @@identity AS lastrowid", (), self + ) # fetchall() ensures the cursor is consumed without closing it row = self.cursor.fetchall()[0] self._lastrowid = int(row[0]) - if (self.isinsert or self.isupdate or self.isdelete) and \ - self.compiled.returning: + if ( + self.isinsert or self.isupdate or self.isdelete + ) and self.compiled.returning: self._result_proxy = engine.FullyBufferedResultProxy(self) if self._enable_identity_insert: conn._cursor_execute( self.cursor, self._opt_encode( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. 
format_table( - self.compiled.statement.table)), + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table( + self.compiled.statement.table + ) + ), (), - self) + self, + ) def get_lastrowid(self): return self._lastrowid @@ -1262,9 +1442,12 @@ class MSExecutionContext(default.DefaultExecutionContext): try: self.cursor.execute( self._opt_encode( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. format_table( - self.compiled.statement.table))) + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table( + self.compiled.statement.table + ) + ) + ) except Exception: pass @@ -1281,11 +1464,12 @@ class MSSQLCompiler(compiler.SQLCompiler): extract_map = util.update_copy( compiler.SQLCompiler.extract_map, { - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond', - 'microseconds': 'microsecond' - }) + "doy": "dayofyear", + "dow": "weekday", + "milliseconds": "millisecond", + "microseconds": "microsecond", + }, + ) def __init__(self, *args, **kwargs): self.tablealiases = {} @@ -1298,6 +1482,7 @@ class MSSQLCompiler(compiler.SQLCompiler): else: super_ = getattr(super(MSSQLCompiler, self), fn.__name__) return super_(*arg, **kw) + return decorate def visit_now_func(self, fn, **kw): @@ -1313,20 +1498,22 @@ class MSSQLCompiler(compiler.SQLCompiler): return "LEN%s" % self.function_argspec(fn, **kw) def visit_concat_op_binary(self, binary, operator, **kw): - return "%s + %s" % \ - (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + return "%s + %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def visit_true(self, expr, **kw): - return '1' + return "1" def visit_false(self, expr, **kw): - return '0' + return "0" def visit_match_op_binary(self, binary, operator, **kw): return "CONTAINS (%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def get_select_precolumns(self, select, 
**kw): """ MS-SQL puts TOP, it's version of LIMIT here """ @@ -1345,7 +1532,8 @@ class MSSQLCompiler(compiler.SQLCompiler): return s else: return compiler.SQLCompiler.get_select_precolumns( - self, select, **kw) + self, select, **kw + ) def get_from_hint_text(self, table, text): return text @@ -1363,20 +1551,21 @@ class MSSQLCompiler(compiler.SQLCompiler): """ if ( - ( - not select._simple_int_limit and - select._limit_clause is not None - ) or ( - select._offset_clause is not None and - not select._simple_int_offset or select._offset + (not select._simple_int_limit and select._limit_clause is not None) + or ( + select._offset_clause is not None + and not select._simple_int_offset + or select._offset ) - ) and not getattr(select, '_mssql_visit', None): + ) and not getattr(select, "_mssql_visit", None): # to use ROW_NUMBER(), an ORDER BY is required. if not select._order_by_clause.clauses: - raise exc.CompileError('MSSQL requires an order_by when ' - 'using an OFFSET or a non-simple ' - 'LIMIT clause') + raise exc.CompileError( + "MSSQL requires an order_by when " + "using an OFFSET or a non-simple " + "LIMIT clause" + ) _order_by_clauses = [ sql_util.unwrap_label_reference(elem) @@ -1385,24 +1574,31 @@ class MSSQLCompiler(compiler.SQLCompiler): limit_clause = select._limit_clause offset_clause = select._offset_clause - kwargs['select_wraps_for'] = select + kwargs["select_wraps_for"] = select select = select._generate() select._mssql_visit = True - select = select.column( - sql.func.ROW_NUMBER().over(order_by=_order_by_clauses) - .label("mssql_rn")).order_by(None).alias() + select = ( + select.column( + sql.func.ROW_NUMBER() + .over(order_by=_order_by_clauses) + .label("mssql_rn") + ) + .order_by(None) + .alias() + ) - mssql_rn = sql.column('mssql_rn') - limitselect = sql.select([c for c in select.c if - c.key != 'mssql_rn']) + mssql_rn = sql.column("mssql_rn") + limitselect = sql.select( + [c for c in select.c if c.key != "mssql_rn"] + ) if offset_clause is not 
None: limitselect.append_whereclause(mssql_rn > offset_clause) if limit_clause is not None: limitselect.append_whereclause( - mssql_rn <= (limit_clause + offset_clause)) + mssql_rn <= (limit_clause + offset_clause) + ) else: - limitselect.append_whereclause( - mssql_rn <= (limit_clause)) + limitselect.append_whereclause(mssql_rn <= (limit_clause)) return self.process(limitselect, **kwargs) else: return compiler.SQLCompiler.visit_select(self, select, **kwargs) @@ -1422,35 +1618,38 @@ class MSSQLCompiler(compiler.SQLCompiler): @_with_legacy_schema_aliasing def visit_alias(self, alias, **kw): # translate for schema-qualified table aliases - kw['mssql_aliased'] = alias.original + kw["mssql_aliased"] = alias.original return super(MSSQLCompiler, self).visit_alias(alias, **kw) @_with_legacy_schema_aliasing def visit_column(self, column, add_to_result_map=None, **kw): - if column.table is not None and \ - (not self.isupdate and not self.isdelete) or \ - self.is_subquery(): + if ( + column.table is not None + and (not self.isupdate and not self.isdelete) + or self.is_subquery() + ): # translate for schema-qualified table aliases t = self._schema_aliased_table(column.table) if t is not None: converted = expression._corresponding_column_or_error( - t, column) + t, column + ) if add_to_result_map is not None: add_to_result_map( column.name, column.name, (column, column.name, column.key), - column.type + column.type, ) - return super(MSSQLCompiler, self).\ - visit_column(converted, **kw) + return super(MSSQLCompiler, self).visit_column(converted, **kw) return super(MSSQLCompiler, self).visit_column( - column, add_to_result_map=add_to_result_map, **kw) + column, add_to_result_map=add_to_result_map, **kw + ) def _schema_aliased_table(self, table): - if getattr(table, 'schema', None) is not None: + if getattr(table, "schema", None) is not None: if table not in self.tablealiases: self.tablealiases[table] = table.alias() return self.tablealiases[table] @@ -1459,16 +1658,17 @@ class 
MSSQLCompiler(compiler.SQLCompiler): def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART(%s, %s)' % \ - (field, self.process(extract.expr, **kw)) + return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw)) def visit_savepoint(self, savepoint_stmt): - return "SAVE TRANSACTION %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return "SAVE TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_rollback_to_savepoint(self, savepoint_stmt): - return ("ROLLBACK TRANSACTION %s" - % self.preparer.format_savepoint(savepoint_stmt)) + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_binary(self, binary, **kwargs): """Move bind parameters to the right-hand side of an operator, where @@ -1481,10 +1681,11 @@ class MSSQLCompiler(compiler.SQLCompiler): and not isinstance(binary.right, expression.BindParameter) ): return self.process( - expression.BinaryExpression(binary.right, - binary.left, - binary.operator), - **kwargs) + expression.BinaryExpression( + binary.right, binary.left, binary.operator + ), + **kwargs + ) return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) def returning_clause(self, stmt, returning_cols): @@ -1497,12 +1698,13 @@ class MSSQLCompiler(compiler.SQLCompiler): adapter = sql_util.ClauseAdapter(target) columns = [ - self._label_select_column(None, adapter.traverse(c), - True, False, {}) + self._label_select_column( + None, adapter.traverse(c), True, False, {} + ) for c in expression._select_iterables(returning_cols) ] - return 'OUTPUT ' + ', '.join(columns) + return "OUTPUT " + ", ".join(columns) def get_cte_preamble(self, recursive): # SQL Server finds it too inconvenient to accept @@ -1515,13 +1717,14 @@ class MSSQLCompiler(compiler.SQLCompiler): if isinstance(column, expression.Function): return column.label(None) else: - return super(MSSQLCompiler, self).\ - label_select_column(select, 
column, asfrom) + return super(MSSQLCompiler, self).label_select_column( + select, column, asfrom + ) def for_update_clause(self, select): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which # SQLAlchemy doesn't use - return '' + return "" def order_by_clause(self, select, **kw): order_by = self.process(select._order_by_clause, **kw) @@ -1532,10 +1735,9 @@ class MSSQLCompiler(compiler.SQLCompiler): else: return "" - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the UPDATE..FROM clause specific to MSSQL. In MSSQL, if the UPDATE statement involves an alias of the table to @@ -1543,13 +1745,12 @@ class MSSQLCompiler(compiler.SQLCompiler): well. Otherwise, it is optional. Here, we add it regardless. """ - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -1558,20 +1759,21 @@ class MSSQLCompiler(compiler.SQLCompiler): self, asfrom=True, iscrud=True, ashint=ashint ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. FROM clause specific to MSSQL. Yes, it has the FROM keyword twice. 
""" - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) def visit_empty_set_expr(self, type_): - return 'SELECT 1 WHERE 1!=1' + return "SELECT 1 WHERE 1!=1" class MSSQLStrictCompiler(MSSQLCompiler): @@ -1583,20 +1785,21 @@ class MSSQLStrictCompiler(MSSQLCompiler): binds are used. """ + ansi_bind_rules = True def visit_in_op_binary(self, binary, operator, **kw): - kw['literal_binds'] = True + kw["literal_binds"] = True return "%s IN %s" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_notin_op_binary(self, binary, operator, **kw): - kw['literal_binds'] = True + kw["literal_binds"] = True return "%s NOT IN %s" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def render_literal_value(self, value, type_): @@ -1615,23 +1818,28 @@ class MSSQLStrictCompiler(MSSQLCompiler): # SQL Server wants single quotes around the date string. 
return "'" + str(value) + "'" else: - return super(MSSQLStrictCompiler, self).\ - render_literal_value(value, type_) + return super(MSSQLStrictCompiler, self).render_literal_value( + value, type_ + ) class MSDDLCompiler(compiler.DDLCompiler): - def get_column_specification(self, column, **kwargs): colspec = ( - self.preparer.format_column(column) + " " + self.preparer.format_column(column) + + " " + self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ) ) if column.nullable is not None: - if not column.nullable or column.primary_key or \ - isinstance(column.default, sa_schema.Sequence) or \ - column.autoincrement is True: + if ( + not column.nullable + or column.primary_key + or isinstance(column.default, sa_schema.Sequence) + or column.autoincrement is True + ): colspec += " NOT NULL" else: colspec += " NULL" @@ -1639,15 +1847,18 @@ class MSDDLCompiler(compiler.DDLCompiler): if column.table is None: raise exc.CompileError( "mssql requires Table-bound columns " - "in order to generate DDL") + "in order to generate DDL" + ) # install an IDENTITY Sequence if we either a sequence or an implicit # IDENTITY column if isinstance(column.default, sa_schema.Sequence): - if (column.default.start is not None or - column.default.increment is not None or - column is not column.table._autoincrement_column): + if ( + column.default.start is not None + or column.default.increment is not None + or column is not column.table._autoincrement_column + ): util.warn_deprecated( "Use of Sequence with SQL Server in order to affect the " "parameters of the IDENTITY value is deprecated, as " @@ -1655,18 +1866,23 @@ class MSDDLCompiler(compiler.DDLCompiler): "will correspond to an actual SQL Server " "CREATE SEQUENCE in " "a future release. Please use the mssql_identity_start " - "and mssql_identity_increment parameters.") + "and mssql_identity_increment parameters." 
+ ) if column.default.start == 0: start = 0 else: start = column.default.start or 1 - colspec += " IDENTITY(%s,%s)" % (start, - column.default.increment or 1) - elif column is column.table._autoincrement_column or \ - column.autoincrement is True: - start = column.dialect_options['mssql']['identity_start'] - increment = column.dialect_options['mssql']['identity_increment'] + colspec += " IDENTITY(%s,%s)" % ( + start, + column.default.increment or 1, + ) + elif ( + column is column.table._autoincrement_column + or column.autoincrement is True + ): + start = column.dialect_options["mssql"]["identity_start"] + increment = column.dialect_options["mssql"]["identity_increment"] colspec += " IDENTITY(%s,%s)" % (start, increment) else: default = self.get_column_default_string(column) @@ -1684,84 +1900,88 @@ class MSDDLCompiler(compiler.DDLCompiler): text += "UNIQUE " # handle clustering option - clustered = index.dialect_options['mssql']['clustered'] + clustered = index.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=include_schema), - preparer.format_table(index.table), - ', '.join( - self.sql_compiler.process(expr, - include_table=False, - literal_binds=True) for - expr in index.expressions) - ) + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table(index.table), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) # handle other included columns - if index.dialect_options['mssql']['include']: - inclusions = [index.table.c[col] - if isinstance(col, util.string_types) else col - for col in - index.dialect_options['mssql']['include'] - ] + if index.dialect_options["mssql"]["include"]: + inclusions = [ + index.table.c[col] + if isinstance(col, 
util.string_types) + else col + for col in index.dialect_options["mssql"]["include"] + ] - text += " INCLUDE (%s)" \ - % ', '.join([preparer.quote(c.name) - for c in inclusions]) + text += " INCLUDE (%s)" % ", ".join( + [preparer.quote(c.name) for c in inclusions] + ) return text def visit_drop_index(self, drop): return "\nDROP INDEX %s ON %s" % ( self._prepared_index_name(drop.element, include_schema=False), - self.preparer.format_table(drop.element.table) + self.preparer.format_table(drop.element.table), ) def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) text += "PRIMARY KEY " - clustered = constraint.dialect_options['mssql']['clustered'] + clustered = constraint.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in constraint) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) text += self.define_constraint_deferrability(constraint) return text def visit_unique_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) text += "UNIQUE " - clustered = constraint.dialect_options['mssql']['clustered'] + clustered = constraint.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in constraint) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) text += 
self.define_constraint_deferrability(constraint) return text @@ -1771,8 +1991,11 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(MSIdentifierPreparer, self).__init__( - dialect, initial_quote='[', - final_quote=']', quote_case_sensitive_collations=False) + dialect, + initial_quote="[", + final_quote="]", + quote_case_sensitive_collations=False, + ) def _escape_identifier(self, value): return value @@ -1783,7 +2006,9 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): dbname, owner = _schema_elements(schema) if dbname: result = "%s.%s" % ( - self.quote(dbname, force), self.quote(owner, force)) + self.quote(dbname, force), + self.quote(owner, force), + ) elif owner: result = self.quote(owner, force) else: @@ -1794,16 +2019,37 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): def _db_plus_owner_listing(fn): def wrap(dialect, connection, schema=None, **kw): dbname, owner = _owner_plus_db(dialect, schema) - return _switch_db(dbname, connection, fn, dialect, connection, - dbname, owner, schema, **kw) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + dbname, + owner, + schema, + **kw + ) + return update_wrapper(wrap, fn) def _db_plus_owner(fn): def wrap(dialect, connection, tablename, schema=None, **kw): dbname, owner = _owner_plus_db(dialect, schema) - return _switch_db(dbname, connection, fn, dialect, connection, - tablename, dbname, owner, schema, **kw) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + tablename, + dbname, + owner, + schema, + **kw + ) + return update_wrapper(wrap, fn) @@ -1837,9 +2083,9 @@ def _schema_elements(schema): for token in re.split(r"(\[|\]|\.)", schema): if not token: continue - if token == '[': + if token == "[": bracket = True - elif token == ']': + elif token == "]": bracket = False elif not bracket and token == ".": push.append(symbol) @@ -1857,7 +2103,7 @@ def _schema_elements(schema): class 
MSDialect(default.DefaultDialect): - name = 'mssql' + name = "mssql" supports_default_values = True supports_empty_insert = False execution_ctx_cls = MSExecutionContext @@ -1871,9 +2117,9 @@ class MSDialect(default.DefaultDialect): sqltypes.Time: TIME, } - engine_config_types = default.DefaultDialect.engine_config_types.union([ - ('legacy_schema_aliasing', util.asbool), - ]) + engine_config_types = default.DefaultDialect.engine_config_types.union( + [("legacy_schema_aliasing", util.asbool)] + ) ischema_names = ischema_names @@ -1890,36 +2136,30 @@ class MSDialect(default.DefaultDialect): preparer = MSIdentifierPreparer construct_arguments = [ - (sa_schema.PrimaryKeyConstraint, { - "clustered": None - }), - (sa_schema.UniqueConstraint, { - "clustered": None - }), - (sa_schema.Index, { - "clustered": None, - "include": None - }), - (sa_schema.Column, { - "identity_start": 1, - "identity_increment": 1 - }) + (sa_schema.PrimaryKeyConstraint, {"clustered": None}), + (sa_schema.UniqueConstraint, {"clustered": None}), + (sa_schema.Index, {"clustered": None, "include": None}), + (sa_schema.Column, {"identity_start": 1, "identity_increment": 1}), ] - def __init__(self, - query_timeout=None, - use_scope_identity=True, - max_identifier_length=None, - schema_name="dbo", - isolation_level=None, - deprecate_large_types=None, - legacy_schema_aliasing=False, **opts): + def __init__( + self, + query_timeout=None, + use_scope_identity=True, + max_identifier_length=None, + schema_name="dbo", + isolation_level=None, + deprecate_large_types=None, + legacy_schema_aliasing=False, + **opts + ): self.query_timeout = int(query_timeout or 0) self.schema_name = schema_name self.use_scope_identity = use_scope_identity - self.max_identifier_length = int(max_identifier_length or 0) or \ - self.max_identifier_length + self.max_identifier_length = ( + int(max_identifier_length or 0) or self.max_identifier_length + ) self.deprecate_large_types = deprecate_large_types self.legacy_schema_aliasing = 
legacy_schema_aliasing @@ -1936,27 +2176,33 @@ class MSDialect(default.DefaultDialect): # SQL Server does not support RELEASE SAVEPOINT pass - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ', - 'SNAPSHOT']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "SNAPSHOT", + ] + ) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") if level not in self._isolation_lookup: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() - cursor.execute( - "SET TRANSACTION ISOLATION LEVEL %s" % level) + cursor.execute("SET TRANSACTION ISOLATION LEVEL %s" % level) cursor.close() def get_isolation_level(self, connection): if self.server_version_info < MS_2005_VERSION: raise NotImplementedError( - "Can't fetch isolation level prior to SQL Server 2005") + "Can't fetch isolation level prior to SQL Server 2005" + ) last_error = None @@ -1964,7 +2210,8 @@ class MSDialect(default.DefaultDialect): for view in views: cursor = connection.cursor() try: - cursor.execute(""" + cursor.execute( + """ SELECT CASE transaction_isolation_level WHEN 0 THEN NULL WHEN 1 THEN 'READ UNCOMMITTED' @@ -1974,7 +2221,9 @@ class MSDialect(default.DefaultDialect): WHEN 5 THEN 'SNAPSHOT' END AS TRANSACTION_ISOLATION_LEVEL FROM %s where session_id = @@SPID - """ % view) + """ + % view + ) val = cursor.fetchone()[0] except self.dbapi.Error as err: # Python3 scoping rules @@ -1987,7 +2236,8 @@ class MSDialect(default.DefaultDialect): else: util.warn( "Could not fetch transaction isolation level, " - "tried views: %s; final error was: %s" % (views, last_error)) + "tried views: %s; final error was: %s" % 
(views, last_error) + ) raise NotImplementedError( "Can't fetch isolation level on this particular " @@ -2000,8 +2250,10 @@ class MSDialect(default.DefaultDialect): def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None @@ -2010,16 +2262,20 @@ class MSDialect(default.DefaultDialect): if self.server_version_info[0] not in list(range(8, 17)): util.warn( "Unrecognized server version info '%s'. Some SQL Server " - "features may not function properly." % - ".".join(str(x) for x in self.server_version_info)) - if self.server_version_info >= MS_2005_VERSION and \ - 'implicit_returning' not in self.__dict__: + "features may not function properly." + % ".".join(str(x) for x in self.server_version_info) + ) + if ( + self.server_version_info >= MS_2005_VERSION + and "implicit_returning" not in self.__dict__ + ): self.implicit_returning = True if self.server_version_info >= MS_2008_VERSION: self.supports_multivalues_insert = True if self.deprecate_large_types is None: - self.deprecate_large_types = \ + self.deprecate_large_types = ( self.server_version_info >= MS_2012_VERSION + ) def _get_default_schema_name(self, connection): if self.server_version_info < MS_2005_VERSION: @@ -2039,17 +2295,19 @@ class MSDialect(default.DefaultDialect): whereclause = columns.c.table_name == tablename if owner: - whereclause = sql.and_(whereclause, - columns.c.table_schema == owner) + whereclause = sql.and_( + whereclause, columns.c.table_schema == owner + ) s = sql.select([columns], whereclause) c = connection.execute(s) return c.first() is not None @reflection.cache def get_schema_names(self, connection, **kw): - s = sql.select([ischema.schemata.c.schema_name], - order_by=[ischema.schemata.c.schema_name] - ) + s = sql.select( + [ischema.schemata.c.schema_name], + order_by=[ischema.schemata.c.schema_name], + ) schema_names = [r[0] for r in connection.execute(s)] return schema_names 
@@ -2057,12 +2315,13 @@ class MSDialect(default.DefaultDialect): @_db_plus_owner_listing def get_table_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables - s = sql.select([tables.c.table_name], - sql.and_( - tables.c.table_schema == owner, - tables.c.table_type == 'BASE TABLE' - ), - order_by=[tables.c.table_name] + s = sql.select( + [tables.c.table_name], + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == "BASE TABLE", + ), + order_by=[tables.c.table_name], ) table_names = [r[0] for r in connection.execute(s)] return table_names @@ -2071,12 +2330,12 @@ class MSDialect(default.DefaultDialect): @_db_plus_owner_listing def get_view_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables - s = sql.select([tables.c.table_name], - sql.and_( - tables.c.table_schema == owner, - tables.c.table_type == 'VIEW' - ), - order_by=[tables.c.table_name] + s = sql.select( + [tables.c.table_name], + sql.and_( + tables.c.table_schema == owner, tables.c.table_type == "VIEW" + ), + order_by=[tables.c.table_name], ) view_names = [r[0] for r in connection.execute(s)] return view_names @@ -2090,30 +2349,33 @@ class MSDialect(default.DefaultDialect): return [] rp = connection.execute( - sql.text("select ind.index_id, ind.is_unique, ind.name " - "from sys.indexes as ind join sys.tables as tab on " - "ind.object_id=tab.object_id " - "join sys.schemas as sch on sch.schema_id=tab.schema_id " - "where tab.name = :tabname " - "and sch.name=:schname " - "and ind.is_primary_key=0 and ind.type != 0", - bindparams=[ - sql.bindparam('tabname', tablename, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) - ], - typemap={ - 'name': sqltypes.Unicode() - } - ) + sql.text( + "select ind.index_id, ind.is_unique, ind.name " + "from sys.indexes as ind join sys.tables as tab on " + "ind.object_id=tab.object_id " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + 
"where tab.name = :tabname " + "and sch.name=:schname " + "and ind.is_primary_key=0 and ind.type != 0", + bindparams=[ + sql.bindparam( + "tabname", + tablename, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), + ], + typemap={"name": sqltypes.Unicode()}, + ) ) indexes = {} for row in rp: - indexes[row['index_id']] = { - 'name': row['name'], - 'unique': row['is_unique'] == 1, - 'column_names': [] + indexes[row["index_id"]] = { + "name": row["name"], + "unique": row["is_unique"] == 1, + "column_names": [], } rp = connection.execute( sql.text( @@ -2127,24 +2389,29 @@ class MSDialect(default.DefaultDialect): "where tab.name=:tabname " "and sch.name=:schname", bindparams=[ - sql.bindparam('tabname', tablename, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) + sql.bindparam( + "tabname", + tablename, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), ], - typemap={'name': sqltypes.Unicode()} - ), + typemap={"name": sqltypes.Unicode()}, + ) ) for row in rp: - if row['index_id'] in indexes: - indexes[row['index_id']]['column_names'].append(row['name']) + if row["index_id"] in indexes: + indexes[row["index_id"]]["column_names"].append(row["name"]) return list(indexes.values()) @reflection.cache @_db_plus_owner - def get_view_definition(self, connection, viewname, - dbname, owner, schema, **kw): + def get_view_definition( + self, connection, viewname, dbname, owner, schema, **kw + ): rp = connection.execute( sql.text( "select definition from sys.sql_modules as mod, " @@ -2155,11 +2422,15 @@ class MSDialect(default.DefaultDialect): "views.schema_id=sch.schema_id and " "views.name=:viewname and sch.name=:schname", bindparams=[ - sql.bindparam('viewname', viewname, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - 
sqltypes.String(convert_unicode=True)) - ] + sql.bindparam( + "viewname", + viewname, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), + ], ) ) @@ -2173,12 +2444,15 @@ class MSDialect(default.DefaultDialect): # Get base columns columns = ischema.columns if owner: - whereclause = sql.and_(columns.c.table_name == tablename, - columns.c.table_schema == owner) + whereclause = sql.and_( + columns.c.table_name == tablename, + columns.c.table_schema == owner, + ) else: whereclause = columns.c.table_name == tablename - s = sql.select([columns], whereclause, - order_by=[columns.c.ordinal_position]) + s = sql.select( + [columns], whereclause, order_by=[columns.c.ordinal_position] + ) c = connection.execute(s) cols = [] @@ -2186,57 +2460,76 @@ class MSDialect(default.DefaultDialect): row = c.fetchone() if row is None: break - (name, type, nullable, charlen, - numericprec, numericscale, default, collation) = ( + ( + name, + type, + nullable, + charlen, + numericprec, + numericscale, + default, + collation, + ) = ( row[columns.c.column_name], row[columns.c.data_type], - row[columns.c.is_nullable] == 'YES', + row[columns.c.is_nullable] == "YES", row[columns.c.character_maximum_length], row[columns.c.numeric_precision], row[columns.c.numeric_scale], row[columns.c.column_default], - row[columns.c.collation_name] + row[columns.c.collation_name], ) coltype = self.ischema_names.get(type, None) kwargs = {} - if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, - MSNText, MSBinary, MSVarBinary, - sqltypes.LargeBinary): + if coltype in ( + MSString, + MSChar, + MSNVarchar, + MSNChar, + MSText, + MSNText, + MSBinary, + MSVarBinary, + sqltypes.LargeBinary, + ): if charlen == -1: charlen = None - kwargs['length'] = charlen + kwargs["length"] = charlen if collation: - kwargs['collation'] = collation + kwargs["collation"] = collation if coltype is None: util.warn( - "Did not recognize type '%s' of column 
'%s'" % - (type, name)) + "Did not recognize type '%s' of column '%s'" % (type, name) + ) coltype = sqltypes.NULLTYPE else: - if issubclass(coltype, sqltypes.Numeric) and \ - coltype is not MSReal: - kwargs['scale'] = numericscale - kwargs['precision'] = numericprec + if ( + issubclass(coltype, sqltypes.Numeric) + and coltype is not MSReal + ): + kwargs["scale"] = numericscale + kwargs["precision"] = numericprec coltype = coltype(**kwargs) cdict = { - 'name': name, - 'type': coltype, - 'nullable': nullable, - 'default': default, - 'autoincrement': False, + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": False, } cols.append(cdict) # autoincrement and identity colmap = {} for col in cols: - colmap[col['name']] = col + colmap[col["name"]] = col # We also run an sp_columns to check for identity columns: - cursor = connection.execute("sp_columns @table_name = '%s', " - "@table_owner = '%s'" - % (tablename, owner)) + cursor = connection.execute( + "sp_columns @table_name = '%s', " + "@table_owner = '%s'" % (tablename, owner) + ) ic = None while True: row = cursor.fetchone() @@ -2245,10 +2538,10 @@ class MSDialect(default.DefaultDialect): (col_name, type_name) = row[3], row[5] if type_name.endswith("identity") and col_name in colmap: ic = col_name - colmap[col_name]['autoincrement'] = True - colmap[col_name]['dialect_options'] = { - 'mssql_identity_start': 1, - 'mssql_identity_increment': 1 + colmap[col_name]["autoincrement"] = True + colmap[col_name]["dialect_options"] = { + "mssql_identity_start": 1, + "mssql_identity_increment": 1, } break cursor.close() @@ -2262,64 +2555,74 @@ class MSDialect(default.DefaultDialect): row = cursor.first() if row is not None and row[0] is not None: - colmap[ic]['dialect_options'].update({ - 'mssql_identity_start': int(row[0]), - 'mssql_identity_increment': int(row[1]) - }) + colmap[ic]["dialect_options"].update( + { + "mssql_identity_start": int(row[0]), + "mssql_identity_increment": 
int(row[1]), + } + ) return cols @reflection.cache @_db_plus_owner - def get_pk_constraint(self, connection, tablename, - dbname, owner, schema, **kw): + def get_pk_constraint( + self, connection, tablename, dbname, owner, schema, **kw + ): pkeys = [] TC = ischema.constraints - C = ischema.key_constraints.alias('C') + C = ischema.key_constraints.alias("C") # Primary key constraints - s = sql.select([C.c.column_name, - TC.c.constraint_type, - C.c.constraint_name], - sql.and_(TC.c.constraint_name == C.c.constraint_name, - TC.c.table_schema == C.c.table_schema, - C.c.table_name == tablename, - C.c.table_schema == owner) - ) + s = sql.select( + [C.c.column_name, TC.c.constraint_type, C.c.constraint_name], + sql.and_( + TC.c.constraint_name == C.c.constraint_name, + TC.c.table_schema == C.c.table_schema, + C.c.table_name == tablename, + C.c.table_schema == owner, + ), + ) c = connection.execute(s) constraint_name = None for row in c: - if 'PRIMARY' in row[TC.c.constraint_type.name]: + if "PRIMARY" in row[TC.c.constraint_type.name]: pkeys.append(row[0]) if constraint_name is None: constraint_name = row[C.c.constraint_name.name] - return {'constrained_columns': pkeys, 'name': constraint_name} + return {"constrained_columns": pkeys, "name": constraint_name} @reflection.cache @_db_plus_owner - def get_foreign_keys(self, connection, tablename, - dbname, owner, schema, **kw): + def get_foreign_keys( + self, connection, tablename, dbname, owner, schema, **kw + ): RR = ischema.ref_constraints - C = ischema.key_constraints.alias('C') - R = ischema.key_constraints.alias('R') + C = ischema.key_constraints.alias("C") + R = ischema.key_constraints.alias("R") # Foreign key constraints - s = sql.select([C.c.column_name, - R.c.table_schema, R.c.table_name, R.c.column_name, - RR.c.constraint_name, RR.c.match_option, - RR.c.update_rule, - RR.c.delete_rule], - sql.and_(C.c.table_name == tablename, - C.c.table_schema == owner, - RR.c.constraint_schema == C.c.table_schema, - 
C.c.constraint_name == RR.c.constraint_name, - R.c.constraint_name == - RR.c.unique_constraint_name, - R.c.constraint_schema == - RR.c.unique_constraint_schema, - C.c.ordinal_position == R.c.ordinal_position - ), - order_by=[RR.c.constraint_name, R.c.ordinal_position] - ) + s = sql.select( + [ + C.c.column_name, + R.c.table_schema, + R.c.table_name, + R.c.column_name, + RR.c.constraint_name, + RR.c.match_option, + RR.c.update_rule, + RR.c.delete_rule, + ], + sql.and_( + C.c.table_name == tablename, + C.c.table_schema == owner, + RR.c.constraint_schema == C.c.table_schema, + C.c.constraint_name == RR.c.constraint_name, + R.c.constraint_name == RR.c.unique_constraint_name, + R.c.constraint_schema == RR.c.unique_constraint_schema, + C.c.ordinal_position == R.c.ordinal_position, + ), + order_by=[RR.c.constraint_name, R.c.ordinal_position], + ) # group rows by constraint ID, to handle multi-column FKs fkeys = [] @@ -2327,11 +2630,11 @@ class MSDialect(default.DefaultDialect): def fkey_rec(): return { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': None, - 'referred_columns': [] + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], } fkeys = util.defaultdict(fkey_rec) @@ -2340,17 +2643,18 @@ class MSDialect(default.DefaultDialect): scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r rec = fkeys[rfknm] - rec['name'] = rfknm - if not rec['referred_table']: - rec['referred_table'] = rtbl + rec["name"] = rfknm + if not rec["referred_table"]: + rec["referred_table"] = rtbl if schema is not None or owner != rschema: if dbname: rschema = dbname + "." 
+ rschema - rec['referred_schema'] = rschema + rec["referred_schema"] = rschema - local_cols, remote_cols = \ - rec['constrained_columns'],\ - rec['referred_columns'] + local_cols, remote_cols = ( + rec["constrained_columns"], + rec["referred_columns"], + ) local_cols.append(scol) remote_cols.append(rcol) |
