summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/mssql/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/dialects/mssql/base.py')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py415
1 files changed, 223 insertions, 192 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 473f7df06..f4264b3d0 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -13,9 +13,9 @@
Auto Increment Behavior
-----------------------
-SQL Server provides so-called "auto incrementing" behavior using the ``IDENTITY``
-construct, which can be placed on an integer primary key. SQLAlchemy
-considers ``IDENTITY`` within its default "autoincrement" behavior,
+SQL Server provides so-called "auto incrementing" behavior using the
+``IDENTITY`` construct, which can be placed on an integer primary key.
+SQLAlchemy considers ``IDENTITY`` within its default "autoincrement" behavior,
described at :paramref:`.Column.autoincrement`; this means
that by default, the first integer primary key column in a :class:`.Table`
will be considered to be the identity column and will generate DDL as such::
@@ -52,24 +52,25 @@ specify ``autoincrement=False`` on all integer primary key columns::
An INSERT statement which refers to an explicit value for such
a column is prohibited by SQL Server, however SQLAlchemy will detect this
and modify the ``IDENTITY_INSERT`` flag accordingly at statement execution
- time. As this is not a high performing process, care should be taken to set
- the ``autoincrement`` flag appropriately for columns that will not actually
- require IDENTITY behavior.
+ time. As this is not a high performing process, care should be taken to
+ set the ``autoincrement`` flag appropriately for columns that will not
+ actually require IDENTITY behavior.
Controlling "Start" and "Increment"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Specific control over the parameters of the ``IDENTITY`` value is supported
-using the :class:`.schema.Sequence` object. While this object normally represents
-an explicit "sequence" for supporting backends, on SQL Server it is re-purposed
-to specify behavior regarding the identity column, including support
-of the "start" and "increment" values::
+using the :class:`.schema.Sequence` object. While this object normally
+represents an explicit "sequence" for supporting backends, on SQL Server it is
+re-purposed to specify behavior regarding the identity column, including
+support of the "start" and "increment" values::
from sqlalchemy import Table, Integer, Sequence, Column
Table('test', metadata,
Column('id', Integer,
- Sequence('blah', start=100, increment=10), primary_key=True),
+ Sequence('blah', start=100, increment=10),
+ primary_key=True),
Column('name', String(20))
).create(some_engine)
@@ -88,10 +89,10 @@ optional and will default to 1,1.
INSERT behavior
^^^^^^^^^^^^^^^^
-Handling of the ``IDENTITY`` column at INSERT time involves two key techniques.
-The most common is being able to fetch the "last inserted value" for a given
-``IDENTITY`` column, a process which SQLAlchemy performs implicitly in many
-cases, most importantly within the ORM.
+Handling of the ``IDENTITY`` column at INSERT time involves two key
+techniques. The most common is being able to fetch the "last inserted value"
+for a given ``IDENTITY`` column, a process which SQLAlchemy performs
+implicitly in many cases, most importantly within the ORM.
The process for fetching this value has several variants:
@@ -106,9 +107,9 @@ The process for fetching this value has several variants:
``implicit_returning=False``, either the ``scope_identity()`` function or
the ``@@identity`` variable is used; behavior varies by backend:
- * when using PyODBC, the phrase ``; select scope_identity()`` will be appended
- to the end of the INSERT statement; a second result set will be fetched
- in order to receive the value. Given a table as::
+ * when using PyODBC, the phrase ``; select scope_identity()`` will be
+ appended to the end of the INSERT statement; a second result set will be
+ fetched in order to receive the value. Given a table as::
t = Table('t', m, Column('id', Integer, primary_key=True),
Column('x', Integer),
@@ -121,17 +122,18 @@ The process for fetching this value has several variants:
INSERT INTO t (x) VALUES (?); select scope_identity()
* Other dialects such as pymssql will call upon
- ``SELECT scope_identity() AS lastrowid`` subsequent to an INSERT statement.
- If the flag ``use_scope_identity=False`` is passed to :func:`.create_engine`,
- the statement ``SELECT @@identity AS lastrowid`` is used instead.
+ ``SELECT scope_identity() AS lastrowid`` subsequent to an INSERT
+ statement. If the flag ``use_scope_identity=False`` is passed to
+ :func:`.create_engine`, the statement ``SELECT @@identity AS lastrowid``
+ is used instead.
A table that contains an ``IDENTITY`` column will prohibit an INSERT statement
that refers to the identity column explicitly. The SQLAlchemy dialect will
detect when an INSERT construct, created using a core :func:`.insert`
construct (not a plain string SQL), refers to the identity column, and
-in this case will emit ``SET IDENTITY_INSERT ON`` prior to the insert statement
-proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the execution.
-Given this example::
+in this case will emit ``SET IDENTITY_INSERT ON`` prior to the insert
+statement proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the
+execution. Given this example::
m = MetaData()
t = Table('t', m, Column('id', Integer, primary_key=True),
@@ -250,7 +252,8 @@ To generate a clustered primary key use::
which will render the table, for example, as::
- CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, PRIMARY KEY CLUSTERED (x, y))
+ CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL,
+ PRIMARY KEY CLUSTERED (x, y))
Similarly, we can generate a clustered unique constraint using::
@@ -272,7 +275,8 @@ for :class:`.Index`.
INCLUDE
^^^^^^^
-The ``mssql_include`` option renders INCLUDE(colname) for the given string names::
+The ``mssql_include`` option renders INCLUDE(colname) for the given string
+names::
Index("my_index", table.c.x, mssql_include=['y'])
@@ -364,13 +368,13 @@ import re
from ... import sql, schema as sa_schema, exc, util
from ...sql import compiler, expression, \
- util as sql_util, cast
+ util as sql_util, cast
from ... import engine
from ...engine import reflection, default
from ... import types as sqltypes
from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \
- FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\
- VARBINARY, TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR
+ FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\
+ VARBINARY, TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR
from ...util import update_wrapper
@@ -409,7 +413,7 @@ RESERVED_WORDS = set(
'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values',
'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with',
'writetext',
- ])
+ ])
class REAL(sqltypes.REAL):
@@ -431,6 +435,7 @@ class TINYINT(sqltypes.Integer):
# not sure about other dialects).
class _MSDate(sqltypes.Date):
+
def bind_processor(self, dialect):
def process(value):
if type(value) == datetime.date:
@@ -447,15 +452,16 @@ class _MSDate(sqltypes.Date):
return value.date()
elif isinstance(value, util.string_types):
return datetime.date(*[
- int(x or 0)
- for x in self._reg.match(value).groups()
- ])
+ int(x or 0)
+ for x in self._reg.match(value).groups()
+ ])
else:
return value
return process
class TIME(sqltypes.TIME):
+
def __init__(self, precision=None, **kwargs):
self.precision = precision
super(TIME, self).__init__()
@@ -466,7 +472,7 @@ class TIME(sqltypes.TIME):
def process(value):
if isinstance(value, datetime.datetime):
value = datetime.datetime.combine(
- self.__zero_date, value.time())
+ self.__zero_date, value.time())
elif isinstance(value, datetime.time):
value = datetime.datetime.combine(self.__zero_date, value)
return value
@@ -480,8 +486,8 @@ class TIME(sqltypes.TIME):
return value.time()
elif isinstance(value, util.string_types):
return datetime.time(*[
- int(x or 0)
- for x in self._reg.match(value).groups()])
+ int(x or 0)
+ for x in self._reg.match(value).groups()])
else:
return value
return process
@@ -489,6 +495,7 @@ _MSTime = TIME
class _DateTimeBase(object):
+
def bind_processor(self, dialect):
def process(value):
if type(value) == datetime.date:
@@ -523,22 +530,21 @@ class DATETIMEOFFSET(sqltypes.TypeEngine):
class _StringType(object):
+
"""Base for MSSQL string types."""
def __init__(self, collation=None):
super(_StringType, self).__init__(collation=collation)
-
-
class NTEXT(sqltypes.UnicodeText):
+
"""MSSQL NTEXT type, for variable-length unicode text up to 2^30
characters."""
__visit_name__ = 'NTEXT'
-
class IMAGE(sqltypes.LargeBinary):
__visit_name__ = 'IMAGE'
@@ -620,6 +626,7 @@ ischema_names = {
class MSTypeCompiler(compiler.GenericTypeCompiler):
+
def _extend(self, spec, type_, length=None):
"""Extend a string-type declaration with standard SQL
COLLATE annotations.
@@ -638,7 +645,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
spec = spec + "(%s)" % length
return ' '.join([c for c in (spec, collation)
- if c is not None])
+ if c is not None])
def visit_FLOAT(self, type_):
precision = getattr(type_, 'precision', None)
@@ -717,9 +724,9 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
def visit_VARBINARY(self, type_):
return self._extend(
- "VARBINARY",
- type_,
- length=type_.length or 'max')
+ "VARBINARY",
+ type_,
+ length=type_.length or 'max')
def visit_boolean(self, type_):
return self.visit_BIT(type_)
@@ -762,20 +769,23 @@ class MSExecutionContext(default.DefaultExecutionContext):
if insert_has_sequence:
self._enable_identity_insert = \
- seq_column.key in self.compiled_parameters[0]
+ seq_column.key in self.compiled_parameters[0]
else:
self._enable_identity_insert = False
self._select_lastrowid = insert_has_sequence and \
- not self.compiled.returning and \
- not self._enable_identity_insert and \
- not self.executemany
+ not self.compiled.returning and \
+ not self._enable_identity_insert and \
+ not self.executemany
if self._enable_identity_insert:
- self.root_connection._cursor_execute(self.cursor,
- self._opt_encode("SET IDENTITY_INSERT %s ON" %
- self.dialect.identifier_preparer.format_table(tbl)),
- (), self)
+ self.root_connection._cursor_execute(
+ self.cursor,
+ self._opt_encode(
+ "SET IDENTITY_INSERT %s ON" %
+ self.dialect.identifier_preparer.format_table(tbl)),
+ (),
+ self)
def post_exec(self):
"""Disable IDENTITY_INSERT if enabled."""
@@ -783,11 +793,14 @@ class MSExecutionContext(default.DefaultExecutionContext):
conn = self.root_connection
if self._select_lastrowid:
if self.dialect.use_scope_identity:
- conn._cursor_execute(self.cursor,
+ conn._cursor_execute(
+ self.cursor,
"SELECT scope_identity() AS lastrowid", (), self)
else:
conn._cursor_execute(self.cursor,
- "SELECT @@identity AS lastrowid", (), self)
+ "SELECT @@identity AS lastrowid",
+ (),
+ self)
# fetchall() ensures the cursor is consumed without closing it
row = self.cursor.fetchall()[0]
self._lastrowid = int(row[0])
@@ -797,11 +810,14 @@ class MSExecutionContext(default.DefaultExecutionContext):
self._result_proxy = engine.FullyBufferedResultProxy(self)
if self._enable_identity_insert:
- conn._cursor_execute(self.cursor,
- self._opt_encode("SET IDENTITY_INSERT %s OFF" %
- self.dialect.identifier_preparer.
- format_table(self.compiled.statement.table)),
- (), self)
+ conn._cursor_execute(
+ self.cursor,
+ self._opt_encode(
+ "SET IDENTITY_INSERT %s OFF" %
+ self.dialect.identifier_preparer. format_table(
+ self.compiled.statement.table)),
+ (),
+ self)
def get_lastrowid(self):
return self._lastrowid
@@ -810,10 +826,10 @@ class MSExecutionContext(default.DefaultExecutionContext):
if self._enable_identity_insert:
try:
self.cursor.execute(
- self._opt_encode("SET IDENTITY_INSERT %s OFF" %
- self.dialect.identifier_preparer.\
- format_table(self.compiled.statement.table))
- )
+ self._opt_encode(
+ "SET IDENTITY_INSERT %s OFF" %
+ self.dialect.identifier_preparer. format_table(
+ self.compiled.statement.table)))
except:
pass
@@ -830,11 +846,11 @@ class MSSQLCompiler(compiler.SQLCompiler):
extract_map = util.update_copy(
compiler.SQLCompiler.extract_map,
{
- 'doy': 'dayofyear',
- 'dow': 'weekday',
- 'milliseconds': 'millisecond',
- 'microseconds': 'microsecond'
- })
+ 'doy': 'dayofyear',
+ 'dow': 'weekday',
+ 'milliseconds': 'millisecond',
+ 'microseconds': 'microsecond'
+ })
def __init__(self, *args, **kwargs):
self.tablealiases = {}
@@ -854,8 +870,8 @@ class MSSQLCompiler(compiler.SQLCompiler):
def visit_concat_op_binary(self, binary, operator, **kw):
return "%s + %s" % \
- (self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ (self.process(binary.left, **kw),
+ self.process(binary.right, **kw))
def visit_true(self, expr, **kw):
return '1'
@@ -865,8 +881,8 @@ class MSSQLCompiler(compiler.SQLCompiler):
def visit_match_op_binary(self, binary, operator, **kw):
return "CONTAINS (%s, %s)" % (
- self.process(binary.left, **kw),
- self.process(binary.right, **kw))
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw))
def get_select_precolumns(self, select):
""" MS-SQL puts TOP, it's version of LIMIT here """
@@ -902,20 +918,20 @@ class MSSQLCompiler(compiler.SQLCompiler):
"""
if (
- (
- not select._simple_int_limit and
- select._limit_clause is not None
- ) or (
- select._offset_clause is not None and
- not select._simple_int_offset or select._offset
- )
- ) and not getattr(select, '_mssql_visit', None):
+ (
+ not select._simple_int_limit and
+ select._limit_clause is not None
+ ) or (
+ select._offset_clause is not None and
+ not select._simple_int_offset or select._offset
+ )
+ ) and not getattr(select, '_mssql_visit', None):
# to use ROW_NUMBER(), an ORDER BY is required.
if not select._order_by_clause.clauses:
raise exc.CompileError('MSSQL requires an order_by when '
- 'using an OFFSET or a non-simple '
- 'LIMIT clause')
+ 'using an OFFSET or a non-simple '
+ 'LIMIT clause')
_order_by_clauses = select._order_by_clause.clauses
limit_clause = select._limit_clause
@@ -923,20 +939,20 @@ class MSSQLCompiler(compiler.SQLCompiler):
select = select._generate()
select._mssql_visit = True
select = select.column(
- sql.func.ROW_NUMBER().over(order_by=_order_by_clauses)
- .label("mssql_rn")).order_by(None).alias()
+ sql.func.ROW_NUMBER().over(order_by=_order_by_clauses)
+ .label("mssql_rn")).order_by(None).alias()
mssql_rn = sql.column('mssql_rn')
limitselect = sql.select([c for c in select.c if
- c.key != 'mssql_rn'])
+ c.key != 'mssql_rn'])
if offset_clause is not None:
limitselect.append_whereclause(mssql_rn > offset_clause)
if limit_clause is not None:
limitselect.append_whereclause(
- mssql_rn <= (limit_clause + offset_clause))
+ mssql_rn <= (limit_clause + offset_clause))
else:
limitselect.append_whereclause(
- mssql_rn <= (limit_clause))
+ mssql_rn <= (limit_clause))
return self.process(limitselect, iswrapper=True, **kwargs)
else:
return compiler.SQLCompiler.visit_select(self, select, **kwargs)
@@ -968,10 +984,11 @@ class MSSQLCompiler(compiler.SQLCompiler):
def visit_extract(self, extract, **kw):
field = self.extract_map.get(extract.field, extract.field)
return 'DATEPART("%s", %s)' % \
- (field, self.process(extract.expr, **kw))
+ (field, self.process(extract.expr, **kw))
def visit_savepoint(self, savepoint_stmt):
- return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
+ return "SAVE TRANSACTION %s" % \
+ self.preparer.format_savepoint(savepoint_stmt)
def visit_rollback_to_savepoint(self, savepoint_stmt):
return ("ROLLBACK TRANSACTION %s"
@@ -979,25 +996,26 @@ class MSSQLCompiler(compiler.SQLCompiler):
def visit_column(self, column, add_to_result_map=None, **kwargs):
if column.table is not None and \
- (not self.isupdate and not self.isdelete) or self.is_subquery():
+ (not self.isupdate and not self.isdelete) or \
+ self.is_subquery():
# translate for schema-qualified table aliases
t = self._schema_aliased_table(column.table)
if t is not None:
converted = expression._corresponding_column_or_error(
- t, column)
+ t, column)
if add_to_result_map is not None:
add_to_result_map(
- column.name,
- column.name,
- (column, column.name, column.key),
- column.type
+ column.name,
+ column.name,
+ (column, column.name, column.key),
+ column.type
)
return super(MSSQLCompiler, self).\
- visit_column(converted, **kwargs)
+ visit_column(converted, **kwargs)
return super(MSSQLCompiler, self).visit_column(
- column, add_to_result_map=add_to_result_map, **kwargs)
+ column, add_to_result_map=add_to_result_map, **kwargs)
def visit_binary(self, binary, **kwargs):
"""Move bind parameters to the right-hand side of an operator, where
@@ -1008,12 +1026,12 @@ class MSSQLCompiler(compiler.SQLCompiler):
isinstance(binary.left, expression.BindParameter)
and binary.operator == operator.eq
and not isinstance(binary.right, expression.BindParameter)
- ):
+ ):
return self.process(
- expression.BinaryExpression(binary.right,
- binary.left,
- binary.operator),
- **kwargs)
+ expression.BinaryExpression(binary.right,
+ binary.left,
+ binary.operator),
+ **kwargs)
return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
def returning_clause(self, stmt, returning_cols):
@@ -1026,10 +1044,10 @@ class MSSQLCompiler(compiler.SQLCompiler):
adapter = sql_util.ClauseAdapter(target)
columns = [
- self._label_select_column(None, adapter.traverse(c),
- True, False, {})
- for c in expression._select_iterables(returning_cols)
- ]
+ self._label_select_column(None, adapter.traverse(c),
+ True, False, {})
+ for c in expression._select_iterables(returning_cols)
+ ]
return 'OUTPUT ' + ', '.join(columns)
@@ -1045,7 +1063,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
return column.label(None)
else:
return super(MSSQLCompiler, self).\
- label_select_column(select, column, asfrom)
+ label_select_column(select, column, asfrom)
def for_update_clause(self, select):
# "FOR UPDATE" is only allowed on "DECLARE CURSOR" which
@@ -1062,9 +1080,9 @@ class MSSQLCompiler(compiler.SQLCompiler):
return ""
def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
+ from_table, extra_froms,
+ from_hints,
+ **kw):
"""Render the UPDATE..FROM clause specific to MSSQL.
In MSSQL, if the UPDATE statement involves an alias of the table to
@@ -1073,12 +1091,13 @@ class MSSQLCompiler(compiler.SQLCompiler):
"""
return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in [from_table] + extra_froms)
+ t._compiler_dispatch(self, asfrom=True,
+ fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms)
class MSSQLStrictCompiler(MSSQLCompiler):
+
"""A subclass of MSSQLCompiler which disables the usage of bind
parameters where not allowed natively by MS-SQL.
@@ -1091,16 +1110,16 @@ class MSSQLStrictCompiler(MSSQLCompiler):
def visit_in_op_binary(self, binary, operator, **kw):
kw['literal_binds'] = True
return "%s IN %s" % (
- self.process(binary.left, **kw),
- self.process(binary.right, **kw)
- )
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)
+ )
def visit_notin_op_binary(self, binary, operator, **kw):
kw['literal_binds'] = True
return "%s NOT IN %s" % (
- self.process(binary.left, **kw),
- self.process(binary.right, **kw)
- )
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw)
+ )
def render_literal_value(self, value, type_):
"""
@@ -1119,10 +1138,11 @@ class MSSQLStrictCompiler(MSSQLCompiler):
return "'" + str(value) + "'"
else:
return super(MSSQLStrictCompiler, self).\
- render_literal_value(value, type_)
+ render_literal_value(value, type_)
class MSDDLCompiler(compiler.DDLCompiler):
+
def get_column_specification(self, column, **kwargs):
colspec = (self.preparer.format_column(column) + " "
+ self.dialect.type_compiler.process(column.type))
@@ -1136,17 +1156,19 @@ class MSDDLCompiler(compiler.DDLCompiler):
if column.table is None:
raise exc.CompileError(
- "mssql requires Table-bound columns "
- "in order to generate DDL")
+ "mssql requires Table-bound columns "
+ "in order to generate DDL")
- # install an IDENTITY Sequence if we either a sequence or an implicit IDENTITY column
+    # install an IDENTITY Sequence if we have either a sequence or an
+    # implicit IDENTITY column
if isinstance(column.default, sa_schema.Sequence):
if column.default.start == 0:
start = 0
else:
start = column.default.start or 1
- colspec += " IDENTITY(%s,%s)" % (start, column.default.increment or 1)
+ colspec += " IDENTITY(%s,%s)" % (start,
+ column.default.increment or 1)
elif column is column.table._autoincrement_column:
colspec += " IDENTITY(1,1)"
else:
@@ -1169,21 +1191,24 @@ class MSDDLCompiler(compiler.DDLCompiler):
text += "CLUSTERED "
text += "INDEX %s ON %s (%s)" \
- % (
- self._prepared_index_name(index,
- include_schema=include_schema),
- preparer.format_table(index.table),
- ', '.join(
- self.sql_compiler.process(expr,
- include_table=False, literal_binds=True) for
- expr in index.expressions)
- )
+ % (
+ self._prepared_index_name(index,
+ include_schema=include_schema),
+ preparer.format_table(index.table),
+ ', '.join(
+ self.sql_compiler.process(expr,
+ include_table=False,
+ literal_binds=True) for
+ expr in index.expressions)
+ )
# handle other included columns
if index.dialect_options['mssql']['include']:
inclusions = [index.table.c[col]
- if isinstance(col, util.string_types) else col
- for col in index.dialect_options['mssql']['include']]
+ if isinstance(col, util.string_types) else col
+ for col in
+ index.dialect_options['mssql']['include']
+ ]
text += " INCLUDE (%s)" \
% ', '.join([preparer.quote(c.name)
@@ -1195,7 +1220,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
return "\nDROP INDEX %s ON %s" % (
self._prepared_index_name(drop.element, include_schema=False),
self.preparer.format_table(drop.element.table)
- )
+ )
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
@@ -1231,6 +1256,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
text += self.define_constraint_deferrability(constraint)
return text
+
class MSIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
@@ -1251,7 +1277,7 @@ def _db_plus_owner_listing(fn):
def wrap(dialect, connection, schema=None, **kw):
dbname, owner = _owner_plus_db(dialect, schema)
return _switch_db(dbname, connection, fn, dialect, connection,
- dbname, owner, schema, **kw)
+ dbname, owner, schema, **kw)
return update_wrapper(wrap, fn)
@@ -1259,7 +1285,7 @@ def _db_plus_owner(fn):
def wrap(dialect, connection, tablename, schema=None, **kw):
dbname, owner = _owner_plus_db(dialect, schema)
return _switch_db(dbname, connection, fn, dialect, connection,
- tablename, dbname, owner, schema, **kw)
+ tablename, dbname, owner, schema, **kw)
return update_wrapper(wrap, fn)
@@ -1334,7 +1360,7 @@ class MSDialect(default.DefaultDialect):
self.use_scope_identity = use_scope_identity
self.max_identifier_length = int(max_identifier_length or 0) or \
- self.max_identifier_length
+ self.max_identifier_length
super(MSDialect, self).__init__(**opts)
def do_savepoint(self, connection, name):
@@ -1359,7 +1385,7 @@ class MSDialect(default.DefaultDialect):
"is configured in the FreeTDS configuration." %
".".join(str(x) for x in self.server_version_info))
if self.server_version_info >= MS_2005_VERSION and \
- 'implicit_returning' not in self.__dict__:
+ 'implicit_returning' not in self.__dict__:
self.implicit_returning = True
if self.server_version_info >= MS_2008_VERSION:
self.supports_multivalues_insert = True
@@ -1395,8 +1421,8 @@ class MSDialect(default.DefaultDialect):
@reflection.cache
def get_schema_names(self, connection, **kw):
s = sql.select([ischema.schemata.c.schema_name],
- order_by=[ischema.schemata.c.schema_name]
- )
+ order_by=[ischema.schemata.c.schema_name]
+ )
schema_names = [r[0] for r in connection.execute(s)]
return schema_names
@@ -1405,10 +1431,10 @@ class MSDialect(default.DefaultDialect):
def get_table_names(self, connection, dbname, owner, schema, **kw):
tables = ischema.tables
s = sql.select([tables.c.table_name],
- sql.and_(
- tables.c.table_schema == owner,
- tables.c.table_type == 'BASE TABLE'
- ),
+ sql.and_(
+ tables.c.table_schema == owner,
+ tables.c.table_type == 'BASE TABLE'
+ ),
order_by=[tables.c.table_name]
)
table_names = [r[0] for r in connection.execute(s)]
@@ -1419,10 +1445,10 @@ class MSDialect(default.DefaultDialect):
def get_view_names(self, connection, dbname, owner, schema, **kw):
tables = ischema.tables
s = sql.select([tables.c.table_name],
- sql.and_(
- tables.c.table_schema == owner,
- tables.c.table_type == 'VIEW'
- ),
+ sql.and_(
+ tables.c.table_schema == owner,
+ tables.c.table_type == 'VIEW'
+ ),
order_by=[tables.c.table_name]
)
view_names = [r[0] for r in connection.execute(s)]
@@ -1438,22 +1464,22 @@ class MSDialect(default.DefaultDialect):
rp = connection.execute(
sql.text("select ind.index_id, ind.is_unique, ind.name "
- "from sys.indexes as ind join sys.tables as tab on "
- "ind.object_id=tab.object_id "
- "join sys.schemas as sch on sch.schema_id=tab.schema_id "
- "where tab.name = :tabname "
- "and sch.name=:schname "
- "and ind.is_primary_key=0",
- bindparams=[
- sql.bindparam('tabname', tablename,
- sqltypes.String(convert_unicode=True)),
- sql.bindparam('schname', owner,
- sqltypes.String(convert_unicode=True))
- ],
- typemap={
- 'name': sqltypes.Unicode()
- }
- )
+ "from sys.indexes as ind join sys.tables as tab on "
+ "ind.object_id=tab.object_id "
+ "join sys.schemas as sch on sch.schema_id=tab.schema_id "
+ "where tab.name = :tabname "
+ "and sch.name=:schname "
+ "and ind.is_primary_key=0",
+ bindparams=[
+ sql.bindparam('tabname', tablename,
+ sqltypes.String(convert_unicode=True)),
+ sql.bindparam('schname', owner,
+ sqltypes.String(convert_unicode=True))
+ ],
+ typemap={
+ 'name': sqltypes.Unicode()
+ }
+ )
)
indexes = {}
for row in rp:
@@ -1473,15 +1499,15 @@ class MSDialect(default.DefaultDialect):
"join sys.schemas as sch on sch.schema_id=tab.schema_id "
"where tab.name=:tabname "
"and sch.name=:schname",
- bindparams=[
- sql.bindparam('tabname', tablename,
- sqltypes.String(convert_unicode=True)),
- sql.bindparam('schname', owner,
- sqltypes.String(convert_unicode=True))
- ],
- typemap={'name': sqltypes.Unicode()}
- ),
- )
+ bindparams=[
+ sql.bindparam('tabname', tablename,
+ sqltypes.String(convert_unicode=True)),
+ sql.bindparam('schname', owner,
+ sqltypes.String(convert_unicode=True))
+ ],
+ typemap={'name': sqltypes.Unicode()}
+ ),
+ )
for row in rp:
if row['index_id'] in indexes:
indexes[row['index_id']]['column_names'].append(row['name'])
@@ -1490,7 +1516,8 @@ class MSDialect(default.DefaultDialect):
@reflection.cache
@_db_plus_owner
- def get_view_definition(self, connection, viewname, dbname, owner, schema, **kw):
+ def get_view_definition(self, connection, viewname,
+ dbname, owner, schema, **kw):
rp = connection.execute(
sql.text(
"select definition from sys.sql_modules as mod, "
@@ -1502,9 +1529,9 @@ class MSDialect(default.DefaultDialect):
"views.name=:viewname and sch.name=:schname",
bindparams=[
sql.bindparam('viewname', viewname,
- sqltypes.String(convert_unicode=True)),
+ sqltypes.String(convert_unicode=True)),
sql.bindparam('schname', owner,
- sqltypes.String(convert_unicode=True))
+ sqltypes.String(convert_unicode=True))
]
)
)
@@ -1524,7 +1551,7 @@ class MSDialect(default.DefaultDialect):
else:
whereclause = columns.c.table_name == tablename
s = sql.select([columns], whereclause,
- order_by=[columns.c.ordinal_position])
+ order_by=[columns.c.ordinal_position])
c = connection.execute(s)
cols = []
@@ -1594,7 +1621,7 @@ class MSDialect(default.DefaultDialect):
ic = col_name
colmap[col_name]['autoincrement'] = True
colmap[col_name]['sequence'] = dict(
- name='%s_identity' % col_name)
+ name='%s_identity' % col_name)
break
cursor.close()
@@ -1603,7 +1630,7 @@ class MSDialect(default.DefaultDialect):
cursor = connection.execute(
"select ident_seed('%s'), ident_incr('%s')"
% (table_fullname, table_fullname)
- )
+ )
row = cursor.first()
if row is not None and row[0] is not None:
@@ -1615,18 +1642,21 @@ class MSDialect(default.DefaultDialect):
@reflection.cache
@_db_plus_owner
- def get_pk_constraint(self, connection, tablename, dbname, owner, schema, **kw):
+ def get_pk_constraint(self, connection, tablename,
+ dbname, owner, schema, **kw):
pkeys = []
TC = ischema.constraints
C = ischema.key_constraints.alias('C')
# Primary key constraints
- s = sql.select([C.c.column_name, TC.c.constraint_type, C.c.constraint_name],
- sql.and_(TC.c.constraint_name == C.c.constraint_name,
- TC.c.table_schema == C.c.table_schema,
- C.c.table_name == tablename,
- C.c.table_schema == owner)
- )
+ s = sql.select([C.c.column_name,
+ TC.c.constraint_type,
+ C.c.constraint_name],
+ sql.and_(TC.c.constraint_name == C.c.constraint_name,
+ TC.c.table_schema == C.c.table_schema,
+ C.c.table_name == tablename,
+ C.c.table_schema == owner)
+ )
c = connection.execute(s)
constraint_name = None
for row in c:
@@ -1638,7 +1668,8 @@ class MSDialect(default.DefaultDialect):
@reflection.cache
@_db_plus_owner
- def get_foreign_keys(self, connection, tablename, dbname, owner, schema, **kw):
+ def get_foreign_keys(self, connection, tablename,
+ dbname, owner, schema, **kw):
RR = ischema.ref_constraints
C = ischema.key_constraints.alias('C')
R = ischema.key_constraints.alias('R')
@@ -1653,11 +1684,11 @@ class MSDialect(default.DefaultDialect):
C.c.table_schema == owner,
C.c.constraint_name == RR.c.constraint_name,
R.c.constraint_name ==
- RR.c.unique_constraint_name,
+ RR.c.unique_constraint_name,
C.c.ordinal_position == R.c.ordinal_position
),
order_by=[RR.c.constraint_name, R.c.ordinal_position]
- )
+ )
# group rows by constraint ID, to handle multi-column FKs
fkeys = []
@@ -1687,8 +1718,8 @@ class MSDialect(default.DefaultDialect):
rec['referred_schema'] = rschema
local_cols, remote_cols = \
- rec['constrained_columns'],\
- rec['referred_columns']
+ rec['constrained_columns'],\
+ rec['referred_columns']
local_cols.append(scol)
remote_cols.append(rcol)