diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/mssql/base.py')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 415 | 
1 files changed, 223 insertions, 192 deletions
| diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 473f7df06..f4264b3d0 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -13,9 +13,9 @@  Auto Increment Behavior  ----------------------- -SQL Server provides so-called "auto incrementing" behavior using the ``IDENTITY`` -construct, which can be placed on an integer primary key.  SQLAlchemy -considers ``IDENTITY`` within its default "autoincrement" behavior, +SQL Server provides so-called "auto incrementing" behavior using the +``IDENTITY`` construct, which can be placed on an integer primary key. +SQLAlchemy considers ``IDENTITY`` within its default "autoincrement" behavior,  described at :paramref:`.Column.autoincrement`; this means  that by default, the first integer primary key column in a :class:`.Table`  will be considered to be the identity column and will generate DDL as such:: @@ -52,24 +52,25 @@ specify ``autoincrement=False`` on all integer primary key columns::      An INSERT statement which refers to an explicit value for such      a column is prohibited by SQL Server, however SQLAlchemy will detect this      and modify the ``IDENTITY_INSERT`` flag accordingly at statement execution -    time.  As this is not a high performing process, care should be taken to set -    the ``autoincrement`` flag appropriately for columns that will not actually -    require IDENTITY behavior. +    time.  As this is not a high performing process, care should be taken to +    set the ``autoincrement`` flag appropriately for columns that will not +    actually require IDENTITY behavior.  Controlling "Start" and "Increment"  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^  Specific control over the parameters of the ``IDENTITY`` value is supported -using the :class:`.schema.Sequence` object.  While this object normally represents -an explicit "sequence" for supporting backends, on SQL Server it is re-purposed -to specify behavior regarding the identity column, including support -of the "start" and "increment" values:: +using the :class:`.schema.Sequence` object.  While this object normally +represents an explicit "sequence" for supporting backends, on SQL Server it is +re-purposed to specify behavior regarding the identity column, including +support of the "start" and "increment" values::      from sqlalchemy import Table, Integer, Sequence, Column      Table('test', metadata,             Column('id', Integer, -                  Sequence('blah', start=100, increment=10), primary_key=True), +                  Sequence('blah', start=100, increment=10), +                  primary_key=True),             Column('name', String(20))           ).create(some_engine) @@ -88,10 +89,10 @@ optional and will default to 1,1.  INSERT behavior  ^^^^^^^^^^^^^^^^ -Handling of the ``IDENTITY`` column at INSERT time involves two key techniques. -The most common is being able to fetch the "last inserted value" for a given -``IDENTITY`` column, a process which SQLAlchemy performs implicitly in many -cases, most importantly within the ORM. +Handling of the ``IDENTITY`` column at INSERT time involves two key +techniques. The most common is being able to fetch the "last inserted value" +for a given ``IDENTITY`` column, a process which SQLAlchemy performs +implicitly in many cases, most importantly within the ORM.  The process for fetching this value has several variants: @@ -106,9 +107,9 @@ The process for fetching this value has several variants:    ``implicit_returning=False``, either the ``scope_identity()`` function or    the ``@@identity`` variable is used; behavior varies by backend: -  * when using PyODBC, the phrase ``; select scope_identity()`` will be appended -    to the end of the INSERT statement; a second result set will be fetched -    in order to receive the value.  Given a table as:: +  * when using PyODBC, the phrase ``; select scope_identity()`` will be +    appended to the end of the INSERT statement; a second result set will be +    fetched in order to receive the value.  Given a table as::          t = Table('t', m, Column('id', Integer, primary_key=True),                  Column('x', Integer), @@ -121,17 +122,18 @@ The process for fetching this value has several variants:          INSERT INTO t (x) VALUES (?); select scope_identity()    * Other dialects such as pymssql will call upon -    ``SELECT scope_identity() AS lastrowid`` subsequent to an INSERT statement. -    If the flag ``use_scope_identity=False`` is passed to :func:`.create_engine`, -    the statement ``SELECT @@identity AS lastrowid`` is used instead. +    ``SELECT scope_identity() AS lastrowid`` subsequent to an INSERT +    statement. If the flag ``use_scope_identity=False`` is passed to +    :func:`.create_engine`, the statement ``SELECT @@identity AS lastrowid`` +    is used instead.  A table that contains an ``IDENTITY`` column will prohibit an INSERT statement  that refers to the identity column explicitly.  The SQLAlchemy dialect will  detect when an INSERT construct, created using a core :func:`.insert`  construct (not a plain string SQL), refers to the identity column, and -in this case will emit ``SET IDENTITY_INSERT ON`` prior to the insert statement -proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the execution. -Given this example:: +in this case will emit ``SET IDENTITY_INSERT ON`` prior to the insert +statement proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the +execution.  Given this example::      m = MetaData()      t = Table('t', m, Column('id', Integer, primary_key=True), @@ -250,7 +252,8 @@ To generate a clustered primary key use::  which will render the table, for example, as:: -  CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, PRIMARY KEY CLUSTERED (x, y)) +  CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, +                         PRIMARY KEY CLUSTERED (x, y))  Similarly, we can generate a clustered unique constraint using:: @@ -272,7 +275,8 @@ for :class:`.Index`.  INCLUDE  ^^^^^^^ -The ``mssql_include`` option renders INCLUDE(colname) for the given string names:: +The ``mssql_include`` option renders INCLUDE(colname) for the given string +names::      Index("my_index", table.c.x, mssql_include=['y']) @@ -364,13 +368,13 @@ import re  from ... import sql, schema as sa_schema, exc, util  from ...sql import compiler, expression, \ -                            util as sql_util, cast +    util as sql_util, cast  from ... import engine  from ...engine import reflection, default  from ... import types as sqltypes  from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ -                                FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\ -                                VARBINARY, TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR +    FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\ +    VARBINARY, TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR  from ...util import update_wrapper @@ -409,7 +413,7 @@ RESERVED_WORDS = set(       'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values',       'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with',       'writetext', -    ]) +     ])  class REAL(sqltypes.REAL): @@ -431,6 +435,7 @@ class TINYINT(sqltypes.Integer):  # not sure about other dialects).  class _MSDate(sqltypes.Date): +      def bind_processor(self, dialect):          def process(value):              if type(value) == datetime.date: @@ -447,15 +452,16 @@ class _MSDate(sqltypes.Date):                  return value.date()              elif isinstance(value, util.string_types):                  return datetime.date(*[ -                        int(x or 0) -                        for x in self._reg.match(value).groups() -                    ]) +                    int(x or 0) +                    for x in self._reg.match(value).groups() +                ])              else:                  return value          return process  class TIME(sqltypes.TIME): +      def __init__(self, precision=None, **kwargs):          self.precision = precision          super(TIME, self).__init__() @@ -466,7 +472,7 @@ class TIME(sqltypes.TIME):          def process(value):              if isinstance(value, datetime.datetime):                  value = datetime.datetime.combine( -                                self.__zero_date, value.time()) +                    self.__zero_date, value.time())              elif isinstance(value, datetime.time):                  value = datetime.datetime.combine(self.__zero_date, value)              return value @@ -480,8 +486,8 @@ class TIME(sqltypes.TIME):                  return value.time()              elif isinstance(value, util.string_types):                  return datetime.time(*[ -                        int(x or 0) -                        for x in self._reg.match(value).groups()]) +                    int(x or 0) +                    for x in self._reg.match(value).groups()])              else:                  return value          return process @@ -489,6 +495,7 @@ _MSTime = TIME  class _DateTimeBase(object): +      def bind_processor(self, dialect):          def process(value):              if type(value) == datetime.date: @@ -523,22 +530,21 @@ class DATETIMEOFFSET(sqltypes.TypeEngine):  class _StringType(object): +      """Base for MSSQL string types."""      def __init__(self, collation=None):          super(_StringType, self).__init__(collation=collation) - -  class NTEXT(sqltypes.UnicodeText): +      """MSSQL NTEXT type, for variable-length unicode text up to 2^30      characters."""      __visit_name__ = 'NTEXT' -  class IMAGE(sqltypes.LargeBinary):      __visit_name__ = 'IMAGE' @@ -620,6 +626,7 @@ ischema_names = {  class MSTypeCompiler(compiler.GenericTypeCompiler): +      def _extend(self, spec, type_, length=None):          """Extend a string-type declaration with standard SQL          COLLATE annotations. @@ -638,7 +645,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):              spec = spec + "(%s)" % length          return ' '.join([c for c in (spec, collation) -            if c is not None]) +                         if c is not None])      def visit_FLOAT(self, type_):          precision = getattr(type_, 'precision', None) @@ -717,9 +724,9 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):      def visit_VARBINARY(self, type_):          return self._extend( -                        "VARBINARY", -                        type_, -                        length=type_.length or 'max') +            "VARBINARY", +            type_, +            length=type_.length or 'max')      def visit_boolean(self, type_):          return self.visit_BIT(type_) @@ -762,20 +769,23 @@ class MSExecutionContext(default.DefaultExecutionContext):              if insert_has_sequence:                  self._enable_identity_insert = \ -                        seq_column.key in self.compiled_parameters[0] +                    seq_column.key in self.compiled_parameters[0]              else:                  self._enable_identity_insert = False              self._select_lastrowid = insert_has_sequence and \ -                                        not self.compiled.returning and \ -                                        not self._enable_identity_insert and \ -                                        not self.executemany +                not self.compiled.returning and \ +                not self._enable_identity_insert and \ +                not self.executemany              if self._enable_identity_insert: -                self.root_connection._cursor_execute(self.cursor, -                    self._opt_encode("SET IDENTITY_INSERT %s ON" % -                    self.dialect.identifier_preparer.format_table(tbl)), -                    (), self) +                self.root_connection._cursor_execute( +                    self.cursor, +                    self._opt_encode( +                        "SET IDENTITY_INSERT %s ON" % +                        self.dialect.identifier_preparer.format_table(tbl)), +                    (), +                    self)      def post_exec(self):          """Disable IDENTITY_INSERT if enabled.""" @@ -783,11 +793,14 @@ class MSExecutionContext(default.DefaultExecutionContext):          conn = self.root_connection          if self._select_lastrowid:              if self.dialect.use_scope_identity: -                conn._cursor_execute(self.cursor, +                conn._cursor_execute( +                    self.cursor,                      "SELECT scope_identity() AS lastrowid", (), self)              else:                  conn._cursor_execute(self.cursor, -                    "SELECT @@identity AS lastrowid", (), self) +                                     "SELECT @@identity AS lastrowid", +                                     (), +                                     self)              # fetchall() ensures the cursor is consumed without closing it              row = self.cursor.fetchall()[0]              self._lastrowid = int(row[0]) @@ -797,11 +810,14 @@ class MSExecutionContext(default.DefaultExecutionContext):              self._result_proxy = engine.FullyBufferedResultProxy(self)          if self._enable_identity_insert: -            conn._cursor_execute(self.cursor, -                        self._opt_encode("SET IDENTITY_INSERT %s OFF" % -                            self.dialect.identifier_preparer. -                                format_table(self.compiled.statement.table)), -                        (), self) +            conn._cursor_execute( +                self.cursor, +                self._opt_encode( +                    "SET IDENTITY_INSERT %s OFF" % +                    self.dialect.identifier_preparer. format_table( +                        self.compiled.statement.table)), +                (), +                self)      def get_lastrowid(self):          return self._lastrowid @@ -810,10 +826,10 @@ class MSExecutionContext(default.DefaultExecutionContext):          if self._enable_identity_insert:              try:                  self.cursor.execute( -                        self._opt_encode("SET IDENTITY_INSERT %s OFF" % -                            self.dialect.identifier_preparer.\ -                            format_table(self.compiled.statement.table)) -                        ) +                    self._opt_encode( +                        "SET IDENTITY_INSERT %s OFF" % +                        self.dialect.identifier_preparer. format_table( +                            self.compiled.statement.table)))              except:                  pass @@ -830,11 +846,11 @@ class MSSQLCompiler(compiler.SQLCompiler):      extract_map = util.update_copy(          compiler.SQLCompiler.extract_map,          { -        'doy': 'dayofyear', -        'dow': 'weekday', -        'milliseconds': 'millisecond', -        'microseconds': 'microsecond' -    }) +            'doy': 'dayofyear', +            'dow': 'weekday', +            'milliseconds': 'millisecond', +            'microseconds': 'microsecond' +        })      def __init__(self, *args, **kwargs):          self.tablealiases = {} @@ -854,8 +870,8 @@ class MSSQLCompiler(compiler.SQLCompiler):      def visit_concat_op_binary(self, binary, operator, **kw):          return "%s + %s" % \ -                (self.process(binary.left, **kw), -                self.process(binary.right, **kw)) +            (self.process(binary.left, **kw), +             self.process(binary.right, **kw))      def visit_true(self, expr, **kw):          return '1' @@ -865,8 +881,8 @@ class MSSQLCompiler(compiler.SQLCompiler):      def visit_match_op_binary(self, binary, operator, **kw):          return "CONTAINS (%s, %s)" % ( -                                        self.process(binary.left, **kw), -                                        self.process(binary.right, **kw)) +            self.process(binary.left, **kw), +            self.process(binary.right, **kw))      def get_select_precolumns(self, select):          """ MS-SQL puts TOP, it's version of LIMIT here """ @@ -902,20 +918,20 @@ class MSSQLCompiler(compiler.SQLCompiler):          """          if ( -                ( -                    not select._simple_int_limit and -                    select._limit_clause is not None -                ) or ( -                    select._offset_clause is not None and -                    not select._simple_int_offset or select._offset -                ) -            ) and not getattr(select, '_mssql_visit', None): +            ( +                not select._simple_int_limit and +                select._limit_clause is not None +            ) or ( +                select._offset_clause is not None and +                not select._simple_int_offset or select._offset +            ) +        ) and not getattr(select, '_mssql_visit', None):              # to use ROW_NUMBER(), an ORDER BY is required.              if not select._order_by_clause.clauses:                  raise exc.CompileError('MSSQL requires an order_by when ' -                                        'using an OFFSET or a non-simple ' -                                        'LIMIT clause') +                                       'using an OFFSET or a non-simple ' +                                       'LIMIT clause')              _order_by_clauses = select._order_by_clause.clauses              limit_clause = select._limit_clause @@ -923,20 +939,20 @@ class MSSQLCompiler(compiler.SQLCompiler):              select = select._generate()              select._mssql_visit = True              select = select.column( -                    sql.func.ROW_NUMBER().over(order_by=_order_by_clauses) -                    .label("mssql_rn")).order_by(None).alias() +                sql.func.ROW_NUMBER().over(order_by=_order_by_clauses) +                .label("mssql_rn")).order_by(None).alias()              mssql_rn = sql.column('mssql_rn')              limitselect = sql.select([c for c in select.c if -                                        c.key != 'mssql_rn']) +                                      c.key != 'mssql_rn'])              if offset_clause is not None:                  limitselect.append_whereclause(mssql_rn > offset_clause)                  if limit_clause is not None:                      limitselect.append_whereclause( -                            mssql_rn <= (limit_clause + offset_clause)) +                        mssql_rn <= (limit_clause + offset_clause))              else:                  limitselect.append_whereclause( -                            mssql_rn <= (limit_clause)) +                    mssql_rn <= (limit_clause))              return self.process(limitselect, iswrapper=True, **kwargs)          else:              return compiler.SQLCompiler.visit_select(self, select, **kwargs) @@ -968,10 +984,11 @@ class MSSQLCompiler(compiler.SQLCompiler):      def visit_extract(self, extract, **kw):          field = self.extract_map.get(extract.field, extract.field)          return 'DATEPART("%s", %s)' % \ -                        (field, self.process(extract.expr, **kw)) +            (field, self.process(extract.expr, **kw))      def visit_savepoint(self, savepoint_stmt): -        return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) +        return "SAVE TRANSACTION %s" % \ +            self.preparer.format_savepoint(savepoint_stmt)      def visit_rollback_to_savepoint(self, savepoint_stmt):          return ("ROLLBACK TRANSACTION %s" @@ -979,25 +996,26 @@ class MSSQLCompiler(compiler.SQLCompiler):      def visit_column(self, column, add_to_result_map=None, **kwargs):          if column.table is not None and \ -            (not self.isupdate and not self.isdelete) or self.is_subquery(): +                (not self.isupdate and not self.isdelete) or \ +                self.is_subquery():              # translate for schema-qualified table aliases              t = self._schema_aliased_table(column.table)              if t is not None:                  converted = expression._corresponding_column_or_error( -                                        t, column) +                    t, column)                  if add_to_result_map is not None:                      add_to_result_map( -                            column.name, -                            column.name, -                            (column, column.name, column.key), -                            column.type +                        column.name, +                        column.name, +                        (column, column.name, column.key), +                        column.type                      )                  return super(MSSQLCompiler, self).\ -                                visit_column(converted, **kwargs) +                    visit_column(converted, **kwargs)          return super(MSSQLCompiler, self).visit_column( -                        column, add_to_result_map=add_to_result_map, **kwargs) +            column, add_to_result_map=add_to_result_map, **kwargs)      def visit_binary(self, binary, **kwargs):          """Move bind parameters to the right-hand side of an operator, where @@ -1008,12 +1026,12 @@ class MSSQLCompiler(compiler.SQLCompiler):              isinstance(binary.left, expression.BindParameter)              and binary.operator == operator.eq              and not isinstance(binary.right, expression.BindParameter) -            ): +        ):              return self.process( -                                expression.BinaryExpression(binary.right, -                                                             binary.left, -                                                             binary.operator), -                                **kwargs) +                expression.BinaryExpression(binary.right, +                                            binary.left, +                                            binary.operator), +                **kwargs)          return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)      def returning_clause(self, stmt, returning_cols): @@ -1026,10 +1044,10 @@ class MSSQLCompiler(compiler.SQLCompiler):          adapter = sql_util.ClauseAdapter(target)          columns = [ -                self._label_select_column(None, adapter.traverse(c), -                                    True, False, {}) -                for c in expression._select_iterables(returning_cols) -            ] +            self._label_select_column(None, adapter.traverse(c), +                                      True, False, {}) +            for c in expression._select_iterables(returning_cols) +        ]          return 'OUTPUT ' + ', '.join(columns) @@ -1045,7 +1063,7 @@ class MSSQLCompiler(compiler.SQLCompiler):              return column.label(None)          else:              return super(MSSQLCompiler, self).\ -                            label_select_column(select, column, asfrom) +                label_select_column(select, column, asfrom)      def for_update_clause(self, select):          # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which @@ -1062,9 +1080,9 @@ class MSSQLCompiler(compiler.SQLCompiler):              return ""      def update_from_clause(self, update_stmt, -                                from_table, extra_froms, -                                from_hints, -                                **kw): +                           from_table, extra_froms, +                           from_hints, +                           **kw):          """Render the UPDATE..FROM clause specific to MSSQL.          In MSSQL, if the UPDATE statement involves an alias of the table to @@ -1073,12 +1091,13 @@ class MSSQLCompiler(compiler.SQLCompiler):          """          return "FROM " + ', '.join( -                    t._compiler_dispatch(self, asfrom=True, -                                    fromhints=from_hints, **kw) -                    for t in [from_table] + extra_froms) +            t._compiler_dispatch(self, asfrom=True, +                                 fromhints=from_hints, **kw) +            for t in [from_table] + extra_froms)  class MSSQLStrictCompiler(MSSQLCompiler): +      """A subclass of MSSQLCompiler which disables the usage of bind      parameters where not allowed natively by MS-SQL. @@ -1091,16 +1110,16 @@ class MSSQLStrictCompiler(MSSQLCompiler):      def visit_in_op_binary(self, binary, operator, **kw):          kw['literal_binds'] = True          return "%s IN %s" % ( -                                self.process(binary.left, **kw), -                                self.process(binary.right, **kw) -            ) +            self.process(binary.left, **kw), +            self.process(binary.right, **kw) +        )      def visit_notin_op_binary(self, binary, operator, **kw):          kw['literal_binds'] = True          return "%s NOT IN %s" % ( -                                self.process(binary.left, **kw), -                                self.process(binary.right, **kw) -            ) +            self.process(binary.left, **kw), +            self.process(binary.right, **kw) +        )      def render_literal_value(self, value, type_):          """ @@ -1119,10 +1138,11 @@ class MSSQLStrictCompiler(MSSQLCompiler):              return "'" + str(value) + "'"          else:              return super(MSSQLStrictCompiler, self).\ -                                render_literal_value(value, type_) +                render_literal_value(value, type_)  class MSDDLCompiler(compiler.DDLCompiler): +      def get_column_specification(self, column, **kwargs):          colspec = (self.preparer.format_column(column) + " "                     + self.dialect.type_compiler.process(column.type)) @@ -1136,17 +1156,19 @@ class MSDDLCompiler(compiler.DDLCompiler):          if column.table is None:              raise exc.CompileError( -                            "mssql requires Table-bound columns " -                            "in order to generate DDL") +                "mssql requires Table-bound columns " +                "in order to generate DDL") -        # install an IDENTITY Sequence if we either a sequence or an implicit IDENTITY column +        # install an IDENTITY Sequence if we either a sequence or an implicit +        # IDENTITY column          if isinstance(column.default, sa_schema.Sequence):              if column.default.start == 0:                  start = 0              else:                  start = column.default.start or 1 -            colspec += " IDENTITY(%s,%s)" % (start, column.default.increment or 1) +            colspec += " IDENTITY(%s,%s)" % (start, +                                             column.default.increment or 1)          elif column is column.table._autoincrement_column:              colspec += " IDENTITY(1,1)"          else: @@ -1169,21 +1191,24 @@ class MSDDLCompiler(compiler.DDLCompiler):              text += "CLUSTERED "          text += "INDEX %s ON %s (%s)" \ -                    % ( -                        self._prepared_index_name(index, -                                include_schema=include_schema), -                        preparer.format_table(index.table), -                       ', '.join( -                            self.sql_compiler.process(expr, -                                include_table=False, literal_binds=True) for -                                expr in index.expressions) -                        ) +            % ( +                self._prepared_index_name(index, +                                          include_schema=include_schema), +                preparer.format_table(index.table), +                ', '.join( +                    self.sql_compiler.process(expr, +                                              include_table=False, +                                              literal_binds=True) for +                    expr in index.expressions) +            )          # handle other included columns          if index.dialect_options['mssql']['include']:              inclusions = [index.table.c[col] -                            if isinstance(col, util.string_types) else col -                          for col in index.dialect_options['mssql']['include']] +                          if isinstance(col, util.string_types) else col +                          for col in +                          index.dialect_options['mssql']['include'] +                          ]              text += " INCLUDE (%s)" \                  % ', '.join([preparer.quote(c.name) @@ -1195,7 +1220,7 @@ class MSDDLCompiler(compiler.DDLCompiler):          return "\nDROP INDEX %s ON %s" % (              self._prepared_index_name(drop.element, include_schema=False),              self.preparer.format_table(drop.element.table) -            ) +        )      def visit_primary_key_constraint(self, constraint):          if len(constraint) == 0: @@ -1231,6 +1256,7 @@ class MSDDLCompiler(compiler.DDLCompiler):          text += self.define_constraint_deferrability(constraint)          return text +  class MSIdentifierPreparer(compiler.IdentifierPreparer):      reserved_words = RESERVED_WORDS @@ -1251,7 +1277,7 @@ def _db_plus_owner_listing(fn):      def wrap(dialect, connection, schema=None, **kw):          dbname, owner = _owner_plus_db(dialect, schema)          return _switch_db(dbname, connection, fn, dialect, connection, -                            dbname, owner, schema, **kw) +                          dbname, owner, schema, **kw)      return update_wrapper(wrap, fn) @@ -1259,7 +1285,7 @@ def _db_plus_owner(fn):      def wrap(dialect, connection, tablename, schema=None, **kw):          dbname, owner = _owner_plus_db(dialect, schema)          return _switch_db(dbname, connection, fn, dialect, connection, -                            tablename, dbname, owner, schema, **kw) +                          tablename, dbname, owner, schema, **kw)      return update_wrapper(wrap, fn) @@ -1334,7 +1360,7 @@ class MSDialect(default.DefaultDialect):          self.use_scope_identity = use_scope_identity          self.max_identifier_length = int(max_identifier_length or 0) or \ -                self.max_identifier_length +            self.max_identifier_length          super(MSDialect, self).__init__(**opts)      def do_savepoint(self, connection, name): @@ -1359,7 +1385,7 @@ class MSDialect(default.DefaultDialect):                  "is configured in the FreeTDS configuration." %                  ".".join(str(x) for x in self.server_version_info))          if self.server_version_info >= MS_2005_VERSION and \ -                    'implicit_returning' not in self.__dict__: +                'implicit_returning' not in self.__dict__:              self.implicit_returning = True          if self.server_version_info >= MS_2008_VERSION:              self.supports_multivalues_insert = True @@ -1395,8 +1421,8 @@ class MSDialect(default.DefaultDialect):      @reflection.cache      def get_schema_names(self, connection, **kw):          s = sql.select([ischema.schemata.c.schema_name], -            order_by=[ischema.schemata.c.schema_name] -        ) +                       order_by=[ischema.schemata.c.schema_name] +                       )          schema_names = [r[0] for r in connection.execute(s)]          return schema_names @@ -1405,10 +1431,10 @@ class MSDialect(default.DefaultDialect):      def get_table_names(self, connection, dbname, owner, schema, **kw):          tables = ischema.tables          s = sql.select([tables.c.table_name], -            sql.and_( -                tables.c.table_schema == owner, -                tables.c.table_type == 'BASE TABLE' -            ), +                       sql.and_( +            tables.c.table_schema == owner, +            tables.c.table_type == 'BASE TABLE' +        ),              order_by=[tables.c.table_name]          )          table_names = [r[0] for r in connection.execute(s)] @@ -1419,10 +1445,10 @@ class MSDialect(default.DefaultDialect):      def get_view_names(self, connection, dbname, owner, schema, **kw):          tables = ischema.tables          s = sql.select([tables.c.table_name], -            sql.and_( -                tables.c.table_schema == owner, -                tables.c.table_type == 'VIEW' -            ), +                       sql.and_( +            tables.c.table_schema == owner, +            tables.c.table_type == 'VIEW' +        ),              order_by=[tables.c.table_name]          )          view_names = [r[0] for r in connection.execute(s)] @@ -1438,22 +1464,22 @@ class MSDialect(default.DefaultDialect):          rp = connection.execute(              sql.text("select ind.index_id, ind.is_unique, ind.name " -                "from sys.indexes as ind join sys.tables as tab on " -                "ind.object_id=tab.object_id " -                "join sys.schemas as sch on sch.schema_id=tab.schema_id " -                "where tab.name = :tabname " -                "and sch.name=:schname " -                "and ind.is_primary_key=0", -                bindparams=[ -                    sql.bindparam('tabname', tablename, -                                    sqltypes.String(convert_unicode=True)), -                    sql.bindparam('schname', owner, -                                    sqltypes.String(convert_unicode=True)) -                ], -                typemap={ -                    'name': sqltypes.Unicode() -                } -            ) +                     "from sys.indexes as ind join sys.tables as tab on " +                     "ind.object_id=tab.object_id " +                     "join sys.schemas as sch on sch.schema_id=tab.schema_id " +                     "where tab.name = :tabname " +                     "and sch.name=:schname " +                     "and ind.is_primary_key=0", +                     bindparams=[ +                         sql.bindparam('tabname', tablename, +                                       sqltypes.String(convert_unicode=True)), +                         sql.bindparam('schname', owner, +                                       sqltypes.String(convert_unicode=True)) +                     ], +                     typemap={ +                         'name': sqltypes.Unicode() +                     } +                     )          )          indexes = {}          for row in rp: @@ -1473,15 +1499,15 @@ class MSDialect(default.DefaultDialect):                  "join sys.schemas as sch on sch.schema_id=tab.schema_id "                  "where tab.name=:tabname "                  "and sch.name=:schname", -                        bindparams=[ -                            sql.bindparam('tabname', tablename, -                                    sqltypes.String(convert_unicode=True)), -                            sql.bindparam('schname', owner, -                                    sqltypes.String(convert_unicode=True)) -                        ], -                        typemap={'name': sqltypes.Unicode()} -                        ), -            ) +                bindparams=[ +                    sql.bindparam('tabname', tablename, +                                  sqltypes.String(convert_unicode=True)), +                    sql.bindparam('schname', owner, +                                  sqltypes.String(convert_unicode=True)) +                ], +                typemap={'name': sqltypes.Unicode()} +            ), +        )          for row in rp:              if row['index_id'] in indexes:                  indexes[row['index_id']]['column_names'].append(row['name']) @@ -1490,7 +1516,8 @@ class MSDialect(default.DefaultDialect):      @reflection.cache      @_db_plus_owner -    def get_view_definition(self, connection, viewname, dbname, owner, schema, **kw): +    def get_view_definition(self, connection, viewname, +                            dbname, owner, schema, **kw):          rp = connection.execute(              sql.text(                  "select definition from sys.sql_modules as mod, " @@ -1502,9 +1529,9 @@ class MSDialect(default.DefaultDialect):                  "views.name=:viewname and sch.name=:schname",                  bindparams=[                      sql.bindparam('viewname', viewname, -                            sqltypes.String(convert_unicode=True)), +                                  sqltypes.String(convert_unicode=True)),                      sql.bindparam('schname', owner, -                            sqltypes.String(convert_unicode=True)) +                                  sqltypes.String(convert_unicode=True))                  ]              )          ) @@ -1524,7 +1551,7 @@ class MSDialect(default.DefaultDialect):          else:              whereclause = columns.c.table_name == tablename          s = sql.select([columns], whereclause, -                        order_by=[columns.c.ordinal_position]) +                       order_by=[columns.c.ordinal_position])          c = connection.execute(s)          cols = [] @@ -1594,7 +1621,7 @@ class MSDialect(default.DefaultDialect):                  ic = col_name                  colmap[col_name]['autoincrement'] = True                  colmap[col_name]['sequence'] = dict( -                                    name='%s_identity' % col_name) +                    name='%s_identity' % col_name)                  break          cursor.close() @@ -1603,7 +1630,7 @@ class MSDialect(default.DefaultDialect):              cursor = connection.execute(                  "select ident_seed('%s'), ident_incr('%s')"                  % (table_fullname, table_fullname) -                ) +            )              row = cursor.first()              if row is not None and row[0] is not None: @@ -1615,18 +1642,21 @@ class MSDialect(default.DefaultDialect):      @reflection.cache      @_db_plus_owner -    def get_pk_constraint(self, connection, tablename, dbname, owner, schema, **kw): +    def get_pk_constraint(self, connection, tablename, +                          dbname, owner, schema, **kw):          pkeys = []          TC = ischema.constraints          C = ischema.key_constraints.alias('C')          # Primary key constraints -        s = sql.select([C.c.column_name, TC.c.constraint_type, C.c.constraint_name], -            sql.and_(TC.c.constraint_name == C.c.constraint_name, -                    TC.c.table_schema == C.c.table_schema, -                     C.c.table_name == tablename, -                     C.c.table_schema == owner) -        ) +        s = sql.select([C.c.column_name, +                        TC.c.constraint_type, +                        C.c.constraint_name], +                       sql.and_(TC.c.constraint_name == C.c.constraint_name, +                                TC.c.table_schema == C.c.table_schema, +                                C.c.table_name == tablename, +                                C.c.table_schema == owner) +                       )          c = connection.execute(s)          constraint_name = None          for row in c: @@ -1638,7 +1668,8 @@ class MSDialect(default.DefaultDialect):      @reflection.cache      @_db_plus_owner -    def get_foreign_keys(self, connection, tablename, dbname, owner, schema, **kw): +    def get_foreign_keys(self, connection, tablename, +                         dbname, owner, schema, **kw):          RR = ischema.ref_constraints          C = ischema.key_constraints.alias('C')          R = ischema.key_constraints.alias('R') @@ -1653,11 +1684,11 @@ class MSDialect(default.DefaultDialect):                                  C.c.table_schema == owner,                                  C.c.constraint_name == RR.c.constraint_name,                                  R.c.constraint_name == -                                                RR.c.unique_constraint_name, +                                RR.c.unique_constraint_name,                                  C.c.ordinal_position == R.c.ordinal_position                                  ),                         order_by=[RR.c.constraint_name, R.c.ordinal_position] -        ) +                       )          # group rows by constraint ID, to handle multi-column FKs          fkeys = [] @@ -1687,8 +1718,8 @@ class MSDialect(default.DefaultDialect):                      rec['referred_schema'] = rschema              local_cols, remote_cols = \ -                                        rec['constrained_columns'],\ -                                        rec['referred_columns'] +                rec['constrained_columns'],\ +                rec['referred_columns']              local_cols.append(scol)              remote_cols.append(rcol) | 
