diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
| commit | 8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch) | |
| tree | ae9e27d12c9fbf8297bb90469509e1cb6a206242 /lib/sqlalchemy/dialects/mssql/base.py | |
| parent | 7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff) | |
| download | sqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz | |
merge 0.6 series to trunk.
Diffstat (limited to 'lib/sqlalchemy/dialects/mssql/base.py')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 1448 |
1 files changed, 1448 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py new file mode 100644 index 000000000..cd031af40 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -0,0 +1,1448 @@ +# mssql.py + +"""Support for the Microsoft SQL Server database. + +Driver +------ + +The MSSQL dialect will work with three different available drivers: + +* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommeded + driver. + +* *pymssql* - http://pymssql.sourceforge.net/ + +* *adodbapi* - http://adodbapi.sourceforge.net/ + +Drivers are loaded in the order listed above based on availability. + +If you need to load a specific driver pass ``module_name`` when +creating the engine:: + + engine = create_engine('mssql+module_name://dsn') + +``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and +``adodbapi``. + +Currently the pyodbc driver offers the greatest level of +compatibility. + +Connecting +---------- + +Connecting with create_engine() uses the standard URL approach of +``mssql://user:pass@host/dbname[?key=value&key=value...]``. + +If the database name is present, the tokens are converted to a +connection string with the specified values. If the database is not +present, then the host token is taken directly as the DSN name. + +Examples of pyodbc connection string URLs: + +* *mssql+pyodbc://mydsn* - connects using the specified DSN named ``mydsn``. + The connection string that is created will appear like:: + + dsn=mydsn;TrustedConnection=Yes + +* *mssql+pyodbc://user:pass@mydsn* - connects using the DSN named + ``mydsn`` passing in the ``UID`` and ``PWD`` information. The + connection string that is created will appear like:: + + dsn=mydsn;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english* - connects + using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD`` + information, plus the additional connection configuration option + ``LANGUAGE``. The connection string that is created will appear + like:: + + dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english + +* *mssql+pyodbc://user:pass@host/db* - connects using a connection string + dynamically created that would appear like:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@host:123/db* - connects using a connection + string that is dynamically created, which also includes the port + information using the comma syntax. If your connection string + requires the port information to be passed as a ``port`` keyword + see the next example. This will create the following connection + string:: + + DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@host/db?port=123* - connects using a connection + string that is dynamically created that includes the port + information as a separate ``port`` keyword. This will create the + following connection string:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123 + +If you require a connection string that is outside the options +presented above, use the ``odbc_connect`` keyword to pass in a +urlencoded connection string. What gets passed in will be urldecoded +and passed directly. + +For example:: + + mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb + +would create the following connection string:: + + dsn=mydsn;Database=db + +Encoding your connection string can be easily accomplished through +the python shell. For example:: + + >>> import urllib + >>> urllib.quote_plus('dsn=mydsn;Database=db') + 'dsn%3Dmydsn%3BDatabase%3Ddb' + +Additional arguments which may be specified either as query string +arguments on the URL, or as keyword argument to +:func:`~sqlalchemy.create_engine()` are: + +* *query_timeout* - allows you to override the default query timeout. + Defaults to ``None``. This is only supported on pymssql. + +* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY + should be used in place of the non-scoped version @@IDENTITY. + Defaults to True. + +* *max_identifier_length* - allows you to se the maximum length of + identfiers supported by the database. Defaults to 128. For pymssql + the default is 30. + +* *schema_name* - use to set the schema name. Defaults to ``dbo``. + +Auto Increment Behavior +----------------------- + +``IDENTITY`` columns are supported by using SQLAlchemy +``schema.Sequence()`` objects. In other words:: + + Table('test', mss_engine, + Column('id', Integer, + Sequence('blah',100,10), primary_key=True), + Column('name', String(20)) + ).create() + +would yield:: + + CREATE TABLE test ( + id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, + name VARCHAR(20) NULL, + ) + +Note that the ``start`` and ``increment`` values for sequences are +optional and will default to 1,1. + +* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for + ``INSERT`` s) + +* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on + ``INSERT`` + +Collation Support +----------------- + +MSSQL specific string types support a collation parameter that +creates a column-level specific collation for the column. The +collation parameter accepts a Windows Collation Name or a SQL +Collation Name. Supported types are MSChar, MSNChar, MSString, +MSNVarchar, MSText, and MSNText. For example:: + + Column('login', String(32, collation='Latin1_General_CI_AS')) + +will yield:: + + login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL + +LIMIT/OFFSET Support +-------------------- + +MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is +supported directly through the ``TOP`` Transact SQL keyword:: + + select.limit + +will yield:: + + SELECT TOP n + +If using SQL Server 2005 or above, LIMIT with OFFSET +support is available through the ``ROW_NUMBER OVER`` construct. +For versions below 2005, LIMIT with OFFSET usage will fail. + +Nullability +----------- +MSSQL has support for three levels of column nullability. The default +nullability allows nulls and is explicit in the CREATE TABLE +construct:: + + name VARCHAR(20) NULL + +If ``nullable=None`` is specified then no specification is made. In +other words the database's configured default is used. This will +render:: + + name VARCHAR(20) + +If ``nullable`` is ``True`` or ``False`` then the column will be +``NULL` or ``NOT NULL`` respectively. + +Date / Time Handling +-------------------- +DATE and TIME are supported. Bind parameters are converted +to datetime.datetime() objects as required by most MSSQL drivers, +and results are processed from strings if needed. +The DATE and TIME types are not available for MSSQL 2005 and +previous - if a server version below 2008 is detected, DDL +for these types will be issued as DATETIME. + +Compatibility Levels +-------------------- +MSSQL supports the notion of setting compatibility levels at the +database level. This allows, for instance, to run a database that +is compatibile with SQL2000 while running on a SQL2005 database +server. ``server_version_info`` will always retrun the database +server version information (in this case SQL2005) and not the +compatibiility level information. Because of this, if running under +a backwards compatibility mode SQAlchemy may attempt to use T-SQL +statements that are unable to be parsed by the database server. + +Known Issues +------------ + +* No support for more than one ``IDENTITY`` column per table + +* pymssql has problems with binary and unicode data that this module + does **not** work around + +""" +import datetime, decimal, inspect, operator, sys, re +import itertools + +from sqlalchemy import sql, schema as sa_schema, exc, util +from sqlalchemy.sql import select, compiler, expression, \ + operators as sql_operators, \ + functions as sql_functions, util as sql_util +from sqlalchemy.engine import default, base, reflection +from sqlalchemy import types as sqltypes +from decimal import Decimal as _python_Decimal +from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ + FLOAT, TIMESTAMP, DATETIME, DATE + + +from sqlalchemy.dialects.mssql import information_schema as ischema + +MS_2008_VERSION = (10,) +MS_2005_VERSION = (9,) +MS_2000_VERSION = (8,) + +RESERVED_WORDS = set( + ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', + 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', + 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', + 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', + 'containstable', 'continue', 'convert', 'create', 'cross', 'current', + 'current_date', 'current_time', 'current_timestamp', 'current_user', + 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', + 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', + 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', + 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', + 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', + 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', + 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', + 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', + 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', + 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', + 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', + 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', + 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', + 'reconfigure', 'references', 'replication', 'restore', 'restrict', + 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', + 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', + 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', + 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', + 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', + 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', + 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', + 'writetext', + ]) + + +class _MSNumeric(sqltypes.Numeric): + def result_processor(self, dialect): + if self.asdecimal: + def process(value): + if value is not None: + return _python_Decimal(str(value)) + else: + return value + return process + else: + def process(value): + return float(value) + return process + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, decimal.Decimal): + if value.adjusted() < 0: + result = "%s0.%s%s" % ( + (value < 0 and '-' or ''), + '0' * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value._int])) + + else: + if 'E' in str(value): + result = "%s%s%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int]), + "0" * (value.adjusted() - (len(value._int)-1))) + else: + if (len(value._int) - 1) > value.adjusted(): + result = "%s%s.%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int][0:value.adjusted() + 1]), + "".join([str(s) for s in value._int][value.adjusted() + 1:])) + else: + result = "%s%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int][0:value.adjusted() + 1])) + + return result + + else: + return value + + return process + +class REAL(sqltypes.Float): + """A type for ``real`` numbers.""" + + __visit_name__ = 'REAL' + + def __init__(self): + super(REAL, self).__init__(precision=24) + +class TINYINT(sqltypes.Integer): + __visit_name__ = 'TINYINT' + + +# MSSQL DATE/TIME types have varied behavior, sometimes returning +# strings. MSDate/TIME check for everything, and always +# filter bind parameters into datetime objects (required by pyodbc, +# not sure about other dialects). + +class _MSDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.date() + elif isinstance(value, basestring): + return datetime.date(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process + +class TIME(sqltypes.TIME): + def __init__(self, precision=None, **kwargs): + self.precision = precision + super(TIME, self).__init__() + + __zero_date = datetime.date(1900, 1, 1) + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine(self.__zero_date, value.time()) + elif isinstance(value, datetime.time): + value = datetime.datetime.combine(self.__zero_date, value) + return value + return process + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, basestring): + return datetime.time(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process + + +class _DateTimeBase(object): + def bind_processor(self, dialect): + def process(value): + # TODO: why ? + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + +class _MSDateTime(_DateTimeBase, sqltypes.DateTime): + pass + +class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'SMALLDATETIME' + +class DATETIME2(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'DATETIME2' + + def __init__(self, precision=None, **kwargs): + self.precision = precision + + +# TODO: is this not an Interval ? +class DATETIMEOFFSET(sqltypes.TypeEngine): + __visit_name__ = 'DATETIMEOFFSET' + + def __init__(self, precision=None, **kwargs): + self.precision = precision + + +class _StringType(object): + """Base for MSSQL string types.""" + + def __init__(self, collation=None): + self.collation = collation + + def __repr__(self): + attributes = inspect.getargspec(self.__init__)[0][1:] + attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) + + params = {} + for attr in attributes: + val = getattr(self, attr) + if val is not None and val is not False: + params[attr] = val + + return "%s(%s)" % (self.__class__.__name__, + ', '.join(['%s=%r' % (k, params[k]) for k in params])) + + +class TEXT(_StringType, sqltypes.TEXT): + """MSSQL TEXT type, for variable-length text up to 2^31 characters.""" + + def __init__(self, *args, **kw): + """Construct a TEXT. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.Text.__init__(self, *args, **kw) + +class NTEXT(_StringType, sqltypes.UnicodeText): + """MSSQL NTEXT type, for variable-length unicode text up to 2^30 + characters.""" + + __visit_name__ = 'NTEXT' + + def __init__(self, *args, **kwargs): + """Construct a NTEXT. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kwargs.pop('collation', None) + _StringType.__init__(self, collation) + length = kwargs.pop('length', None) + sqltypes.UnicodeText.__init__(self, length, **kwargs) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum + of 8,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a VARCHAR. + + :param length: Optinal, maximum data length, in characters. + + :param convert_unicode: defaults to False. If True, convert + ``unicode`` data sent to the database to a ``str`` + bytestring, and convert bytestrings coming back from the + database into ``unicode``. + + Bytestrings are encoded using the dialect's + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which + defaults to `utf-8`. + + If False, may be overridden by + :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. + + :param assert_unicode: + + If None (the default), no assertion will take place unless + overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. + + If 'warn', will issue a runtime warning if a ``str`` + instance is used as a bind value. + + If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.VARCHAR.__init__(self, *args, **kw) + +class NVARCHAR(_StringType, sqltypes.NVARCHAR): + """MSSQL NVARCHAR type. + + For variable-length unicode character data up to 4,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a NVARCHAR. + + :param length: Optional, Maximum data length, in characters. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NVARCHAR.__init__(self, *args, **kw) + +class CHAR(_StringType, sqltypes.CHAR): + """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum + of 8,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a CHAR. + + :param length: Optinal, maximum data length, in characters. + + :param convert_unicode: defaults to False. If True, convert + ``unicode`` data sent to the database to a ``str`` + bytestring, and convert bytestrings coming back from the + database into ``unicode``. + + Bytestrings are encoded using the dialect's + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which + defaults to `utf-8`. + + If False, may be overridden by + :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. + + :param assert_unicode: + + If None (the default), no assertion will take place unless + overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. + + If 'warn', will issue a runtime warning if a ``str`` + instance is used as a bind value. + + If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.CHAR.__init__(self, *args, **kw) + +class NCHAR(_StringType, sqltypes.NCHAR): + """MSSQL NCHAR type. + + For fixed-length unicode character data up to 4,000 characters.""" + + def __init__(self, *args, **kw): + """Construct an NCHAR. + + :param length: Optional, Maximum data length, in characters. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NCHAR.__init__(self, *args, **kw) + +class BINARY(sqltypes.Binary): + __visit_name__ = 'BINARY' + +class VARBINARY(sqltypes.Binary): + __visit_name__ = 'VARBINARY' + +class IMAGE(sqltypes.Binary): + __visit_name__ = 'IMAGE' + +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + +class _MSBoolean(sqltypes.Boolean): + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = 'MONEY' + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = 'SMALLMONEY' + +class UNIQUEIDENTIFIER(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + +class SQL_VARIANT(sqltypes.TypeEngine): + __visit_name__ = 'SQL_VARIANT' + +# old names. +MSNumeric = _MSNumeric +MSDateTime = _MSDateTime +MSDate = _MSDate +MSBoolean = _MSBoolean +MSReal = REAL +MSTinyInteger = TINYINT +MSTime = TIME +MSSmallDateTime = SMALLDATETIME +MSDateTime2 = DATETIME2 +MSDateTimeOffset = DATETIMEOFFSET +MSText = TEXT +MSNText = NTEXT +MSString = VARCHAR +MSNVarchar = NVARCHAR +MSChar = CHAR +MSNChar = NCHAR +MSBinary = BINARY +MSVarBinary = VARBINARY +MSImage = IMAGE +MSBit = BIT +MSMoney = MONEY +MSSmallMoney = SMALLMONEY +MSUniqueIdentifier = UNIQUEIDENTIFIER +MSVariant = SQL_VARIANT + +colspecs = { + sqltypes.Numeric : _MSNumeric, + sqltypes.DateTime : _MSDateTime, + sqltypes.Date : _MSDate, + sqltypes.Time : TIME, + sqltypes.Boolean : _MSBoolean, +} + +ischema_names = { + 'int' : INTEGER, + 'bigint': BIGINT, + 'smallint' : SMALLINT, + 'tinyint' : TINYINT, + 'varchar' : VARCHAR, + 'nvarchar' : NVARCHAR, + 'char' : CHAR, + 'nchar' : NCHAR, + 'text' : TEXT, + 'ntext' : NTEXT, + 'decimal' : DECIMAL, + 'numeric' : NUMERIC, + 'float' : FLOAT, + 'datetime' : DATETIME, + 'datetime2' : DATETIME2, + 'datetimeoffset' : DATETIMEOFFSET, + 'date': DATE, + 'time': TIME, + 'smalldatetime' : SMALLDATETIME, + 'binary' : BINARY, + 'varbinary' : VARBINARY, + 'bit': BIT, + 'real' : REAL, + 'image' : IMAGE, + 'timestamp': TIMESTAMP, + 'money': MONEY, + 'smallmoney': SMALLMONEY, + 'uniqueidentifier': UNIQUEIDENTIFIER, + 'sql_variant': SQL_VARIANT, +} + + +class MSTypeCompiler(compiler.GenericTypeCompiler): + def _extend(self, spec, type_): + """Extend a string-type declaration with standard SQL + COLLATE annotations. + + """ + + if getattr(type_, 'collation', None): + collation = 'COLLATE %s' % type_.collation + else: + collation = None + + if type_.length: + spec = spec + "(%d)" % type_.length + + return ' '.join([c for c in (spec, collation) + if c is not None]) + + def visit_FLOAT(self, type_): + precision = getattr(type_, 'precision', None) + if precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': precision} + + def visit_REAL(self, type_): + return "REAL" + + def visit_TINYINT(self, type_): + return "TINYINT" + + def visit_DATETIMEOFFSET(self, type_): + if type_.precision: + return "DATETIMEOFFSET(%s)" % type_.precision + else: + return "DATETIMEOFFSET" + + def visit_TIME(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "TIME(%s)" % precision + else: + return "TIME" + + def visit_DATETIME2(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "DATETIME2(%s)" % precision + else: + return "DATETIME2" + + def visit_SMALLDATETIME(self, type_): + return "SMALLDATETIME" + + def visit_unicode(self, type_): + return self.visit_NVARCHAR(type_) + + def visit_unicode_text(self, type_): + return self.visit_NTEXT(type_) + + def visit_NTEXT(self, type_): + return self._extend("NTEXT", type_) + + def visit_TEXT(self, type_): + return self._extend("TEXT", type_) + + def visit_VARCHAR(self, type_): + return self._extend("VARCHAR", type_) + + def visit_CHAR(self, type_): + return self._extend("CHAR", type_) + + def visit_NCHAR(self, type_): + return self._extend("NCHAR", type_) + + def visit_NVARCHAR(self, type_): + return self._extend("NVARCHAR", type_) + + def visit_date(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return self.visit_DATE(type_) + + def visit_time(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return self.visit_TIME(type_) + + def visit_binary(self, type_): + if type_.length: + return self.visit_BINARY(type_) + else: + return self.visit_IMAGE(type_) + + def visit_BINARY(self, type_): + if type_.length: + return "BINARY(%s)" % type_.length + else: + return "BINARY" + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_VARBINARY(self, type_): + if type_.length: + return "VARBINARY(%s)" % type_.length + else: + return "VARBINARY" + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return 'SMALLMONEY' + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + + def visit_SQL_VARIANT(self, type_): + return 'SQL_VARIANT' + +class MSExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + _select_lastrowid = False + _result_proxy = None + _lastrowid = None + + def pre_exec(self): + """Activate IDENTITY_INSERT if needed.""" + + if self.isinsert: + tbl = self.compiled.statement.table + seq_column = tbl._autoincrement_column + insert_has_sequence = seq_column is not None + + if insert_has_sequence: + self._enable_identity_insert = seq_column.key in self.compiled_parameters[0] + else: + self._enable_identity_insert = False + + self._select_lastrowid = insert_has_sequence and \ + not self.compiled.returning and \ + not self._enable_identity_insert and \ + not self.executemany + + if self._enable_identity_insert: + self.cursor.execute("SET IDENTITY_INSERT %s ON" % + self.dialect.identifier_preparer.format_table(tbl)) + + def post_exec(self): + """Disable IDENTITY_INSERT if enabled.""" + + if self._select_lastrowid: + if self.dialect.use_scope_identity: + self.cursor.execute("SELECT scope_identity() AS lastrowid") + else: + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchall()[0] # fetchall() ensures the cursor is consumed without closing it + self._lastrowid = int(row[0]) + + if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning: + self._result_proxy = base.FullyBufferedResultProxy(self) + + if self._enable_identity_insert: + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) + + def get_lastrowid(self): + return self._lastrowid + + def handle_dbapi_exception(self, e): + if self._enable_identity_insert: + try: + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) + except: + pass + + def get_result_proxy(self): + if self._result_proxy: + return self._result_proxy + else: + return base.ResultProxy(self) + +class MSSQLCompiler(compiler.SQLCompiler): + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update ({ + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond', + 'microseconds': 'microsecond' + }) + + def __init__(self, *args, **kwargs): + super(MSSQLCompiler, self).__init__(*args, **kwargs) + self.tablealiases = {} + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_current_date_func(self, fn, **kw): + return "GETDATE()" + + def visit_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_char_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_concat_op(self, binary): + return "%s + %s" % (self.process(binary.left), self.process(binary.right)) + + def visit_match_op(self, binary): + return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def get_select_precolumns(self, select): + """ MS-SQL puts TOP, it's version of LIMIT here """ + if select._distinct or select._limit: + s = select._distinct and "DISTINCT " or "" + + if select._limit: + if not select._offset: + s += "TOP %s " % (select._limit,) + return s + return compiler.SQLCompiler.get_select_precolumns(self, select) + + def limit_clause(self, select): + # Limit in mssql is after the select keyword + return "" + + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``row_number()`` criterion. + + """ + if not getattr(select, '_mssql_visit', None) and select._offset: + # to use ROW_NUMBER(), an ORDER BY is required. + orderby = self.process(select._order_by_clause) + if not orderby: + raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.') + + _offset = select._offset + _limit = select._limit + select._mssql_visit = True + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias() + + limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) + limitselect.append_whereclause("mssql_rn>%d" % _offset) + if _limit is not None: + limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) + return self.process(limitselect, iswrapper=True, **kwargs) + else: + return compiler.SQLCompiler.visit_select(self, select, **kwargs) + + def _schema_aliased_table(self, table): + if getattr(table, 'schema', None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_table(self, table, mssql_aliased=False, **kwargs): + if mssql_aliased: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + # alias schema-qualified tables + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=True, **kwargs) + else: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + def visit_alias(self, alias, **kwargs): + # translate for schema-qualified table aliases + self.tablealiases[alias.original] = alias + kwargs['mssql_aliased'] = True + return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) + + def visit_extract(self, extract): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_column(self, column, result_map=None, **kwargs): + if column.table is not None and \ + (not self.isupdate and not self.isdelete) or self.is_subquery(): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = expression._corresponding_column_or_error(t, column) + + if result_map is not None: + result_map[column.name.lower()] = (column.name, (column, ), column.type) + + return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs) + + return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs) + + def visit_binary(self, binary, **kwargs): + """Move bind parameters to the right-hand side of an operator, where + possible. + + """ + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \ + and not isinstance(binary.right, expression._BindParamClause): + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs) + else: + if (binary.operator is operator.eq or binary.operator is operator.ne) and ( + (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \ + (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \ + isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)): + op = binary.operator == operator.eq and "IN" or "NOT IN" + return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) + return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) + + def returning_clause(self, stmt, returning_cols): + + if self.isinsert or self.isupdate: + target = stmt.table.alias("inserted") + else: + target = stmt.table.alias("deleted") + + adapter = sql_util.ClauseAdapter(target) + def col_label(col): + adapted = adapter.traverse(c) + if isinstance(c, expression._Label): + return adapted.label(c.key) + else: + return self.label_select_column(None, adapted, asfrom=False) + + columns = [ + self.process( + col_label(c), + within_columns_clause=True, + result_map=self.result_map + ) + for c in expression._select_iterables(returning_cols) + ] + return 'OUTPUT ' + ', '.join(columns) + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super(MSSQLCompiler, self).label_select_column(select, column, asfrom) + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # MSSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + +class MSDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type) + + if column.nullable is not None: + if not column.nullable or column.primary_key: + colspec += " NOT NULL" + else: + colspec += " NULL" + + if not column.table: + raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL") + + seq_col = column.table._autoincrement_column + + # install a IDENTITY Sequence if we have an implicit IDENTITY column + if seq_col is column: + sequence = getattr(column, 'sequence', None) + if sequence: + start, increment = sequence.start or 1, sequence.increment or 1 + else: + start, increment = 1, 1 + colspec += " IDENTITY(%s,%s)" % (start, increment) + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_drop_index(self, drop): + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(drop.element.table.name), + self.preparer.quote(self._validate_identifier(drop.element.name, False), drop.element.quote) + ) + + +class MSIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') + + def _escape_identifier(self, value): + #TODO: determine MSSQL's escaping rules + return value + + def quote_schema(self, schema, force=True): + """Prepare a quoted table and schema name.""" + result = '.'.join([self.quote(x, force) for x in schema.split('.')]) + return result + +class MSDialect(default.DefaultDialect): + name = 'mssql' + supports_default_values = True + supports_empty_insert = False + execution_ctx_cls = MSExecutionContext + text_as_varchar = False + use_scope_identity = True + max_identifier_length = 128 + schema_name = "dbo" + colspecs = colspecs + ischema_names = ischema_names + + supports_unicode_binds = True + postfetch_lastrowid = True + + server_version_info = () + + statement_compiler = MSSQLCompiler + ddl_compiler = MSDDLCompiler + type_compiler = MSTypeCompiler + preparer = MSIdentifierPreparer + + def __init__(self, + query_timeout=None, + use_scope_identity=True, + max_identifier_length=None, + schema_name="dbo", **opts): + self.query_timeout = int(query_timeout or 0) + self.schema_name = schema_name + + self.use_scope_identity = use_scope_identity + self.max_identifier_length = int(max_identifier_length or 0) or \ + self.max_identifier_length + super(MSDialect, self).__init__(**opts) + + def do_savepoint(self, connection, name): + util.warn("Savepoint support in mssql is experimental and may lead to data loss.") + connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") + connection.execute("SAVE TRANSACTION %s" % name) + + def do_release_savepoint(self, connection, name): + pass + + def initialize(self, connection): + super(MSDialect, self).initialize(connection) + if self.server_version_info >= MS_2005_VERSION and 'implicit_returning' not in self.__dict__: + self.implicit_returning = True + + def get_default_schema_name(self, connection): + return self.default_schema_name + + def _get_default_schema_name(self, connection): + user_name = connection.scalar("SELECT user_name() as user_name;") + if user_name is not None: + # now, get the default schema + query = """ + SELECT default_schema_name FROM + sys.database_principals + WHERE name = ? + AND type = 'S' + """ + try: + default_schema_name = connection.scalar(query, [user_name]) + if default_schema_name is not None: + return default_schema_name + except: + pass + return self.schema_name + + def table_names(self, connection, schema): + s = select([ischema.tables.c.table_name], ischema.tables.c.table_schema==schema) + return [row[0] for row in connection.execute(s)] + + + def has_table(self, connection, tablename, schema=None): + current_schema = schema or self.default_schema_name + columns = ischema.columns + s = sql.select([columns], + current_schema + and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) + or columns.c.table_name==tablename, + ) + + c = connection.execute(s) + row = c.fetchone() + return row is not None + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = sql.select([ischema.schemata.c.schema_name], + order_by=[ischema.schemata.c.schema_name] + ) + schema_names = [r[0] for r in connection.execute(s)] + return schema_names + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + current_schema = schema or self.default_schema_name + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == current_schema, + tables.c.table_type == 'BASE TABLE' + ), + order_by=[tables.c.table_name] + ) + table_names = [r[0] for r in connection.execute(s)] + return table_names + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + current_schema = schema or self.default_schema_name + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == current_schema, + tables.c.table_type == 'VIEW' + ), + order_by=[tables.c.table_name] + ) + view_names = [r[0] for r in connection.execute(s)] + return view_names + + # The cursor reports it is closed after executing the sp. + @reflection.cache + def get_indexes(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + full_tname = "%s.%s" % (current_schema, tablename) + indexes = [] + s = sql.text("exec sp_helpindex '%s'" % full_tname) + rp = connection.execute(s) + if rp.closed: + # did not work for this setup. + return [] + for row in rp: + if 'primary key' not in row['index_description']: + indexes.append({ + 'name' : row['index_name'], + 'column_names' : [c.strip() for c in row['index_keys'].split(',')], + 'unique': 'unique' in row['index_description'] + }) + return indexes + + @reflection.cache + def get_view_definition(self, connection, viewname, schema=None, **kw): + current_schema = schema or self.default_schema_name + views = ischema.views + s = sql.select([views.c.view_definition], + sql.and_( + views.c.table_schema == current_schema, + views.c.table_name == viewname + ), + ) + rp = connection.execute(s) + if rp: + view_def = rp.scalar() + return view_def + + @reflection.cache + def get_columns(self, connection, tablename, schema=None, **kw): + # Get base columns + current_schema = schema or self.default_schema_name + columns = ischema.columns + s = sql.select([columns], + current_schema + and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) + or columns.c.table_name==tablename, + order_by=[columns.c.ordinal_position]) + c = connection.execute(s) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + (name, type, nullable, charlen, numericprec, numericscale, default, collation) = ( + row[columns.c.column_name], + row[columns.c.data_type], + row[columns.c.is_nullable] == 'YES', + row[columns.c.character_maximum_length], + row[columns.c.numeric_precision], + row[columns.c.numeric_scale], + row[columns.c.column_default], + row[columns.c.collation_name] + ) + coltype = self.ischema_names.get(type, None) + + kwargs = {} + if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.Binary): + kwargs['length'] = charlen + if collation: + kwargs['collation'] = collation + if coltype == MSText or (coltype in (MSString, MSNVarchar) and charlen == -1): + kwargs.pop('length') + + if coltype is None: + util.warn("Did not recognize type '%s' of column '%s'" % (type, name)) + coltype = sqltypes.NULLTYPE + + if issubclass(coltype, sqltypes.Numeric) and coltype is not MSReal: + kwargs['scale'] = numericscale + kwargs['precision'] = numericprec + + coltype = coltype(**kwargs) + cdict = { + 'name' : name, + 'type' : coltype, + 'nullable' : nullable, + 'default' : default, + 'autoincrement':False, + } + cols.append(cdict) + # autoincrement and identity + colmap = {} + for col in cols: + colmap[col['name']] = col + # We also run an sp_columns to check for identity columns: + cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (tablename, current_schema)) + ic = None + while True: + row = cursor.fetchone() + if row is None: + break + (col_name, type_name) = row[3], row[5] + if type_name.endswith("identity") and col_name in colmap: + ic = col_name + colmap[col_name]['autoincrement'] = True + colmap[col_name]['sequence'] = dict( + name='%s_identity' % col_name) + break + cursor.close() + if ic is not None: + try: + # is this table_fullname reliable? + table_fullname = "%s.%s" % (current_schema, tablename) + cursor = connection.execute( + sql.text("select ident_seed(:seed), ident_incr(:incr)"), + {'seed':table_fullname, 'incr':table_fullname} + ) + row = cursor.fetchone() + cursor.close() + if not row is None: + colmap[ic]['sequence'].update({ + 'start' : int(row[0]), + 'increment' : int(row[1]) + }) + except: + # ignoring it, works just like before + pass + return cols + + @reflection.cache + def get_primary_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + pkeys = [] + # Add constraints + RR = ischema.ref_constraints #information_schema.referential_constraints + TC = ischema.constraints #information_schema.table_constraints + C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column + R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column + + # Primary key constraints + s = sql.select([C.c.column_name, TC.c.constraint_type], + sql.and_(TC.c.constraint_name == C.c.constraint_name, + C.c.table_name == tablename, + C.c.table_schema == current_schema) + ) + c = connection.execute(s) + for row in c: + if 'PRIMARY' in row[TC.c.constraint_type.name]: + pkeys.append(row[0]) + return pkeys + + @reflection.cache + def get_foreign_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + # Add constraints + RR = ischema.ref_constraints #information_schema.referential_constraints + TC = ischema.constraints #information_schema.table_constraints + C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column + R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column + + # Foreign key constraints + s = sql.select([C.c.column_name, + R.c.table_schema, R.c.table_name, R.c.column_name, + RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule], + sql.and_(C.c.table_name == tablename, + C.c.table_schema == current_schema, + C.c.constraint_name == RR.c.constraint_name, + R.c.constraint_name == RR.c.unique_constraint_name, + C.c.ordinal_position == R.c.ordinal_position + ), + order_by = [RR.c.constraint_name, R.c.ordinal_position]) + + + # group rows by constraint ID, to handle multi-column FKs + fkeys = [] + fknm, scols, rcols = (None, [], []) + + def fkey_rec(): + return { + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + } + + fkeys = util.defaultdict(fkey_rec) + + for r in connection.execute(s).fetchall(): + scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r + + rec = fkeys[rfknm] + rec['name'] = rfknm + if not rec['referred_table']: + rec['referred_table'] = rtbl + + if schema is not None or current_schema != rschema: + rec['referred_schema'] = rschema + + local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + + local_cols.append(scol) + remote_cols.append(rcol) + + return fkeys.values() + + +# fixme. I added this for the tests to run. -Randall +MSSQLDialect = MSDialect |
