diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/mssql')
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/__init__.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/adodbapi.py | 51 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/base.py | 1448 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/information_schema.py | 83 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/pymssql.py | 46 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mssql/pyodbc.py | 79 |
6 files changed, 1710 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py new file mode 100644 index 000000000..e3a829047 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.mssql import base, pyodbc, adodbapi, pymssql + +base.dialect = pyodbc.dialect
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py new file mode 100644 index 000000000..10b8b33b3 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py @@ -0,0 +1,51 @@ +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect +import sys + +class MSDateTime_adodbapi(MSDateTime): + def result_processor(self, dialect): + def process(value): + # adodbapi will return datetimes with empty time values as datetime.date() objects. + # Promote them back to full datetime.datetime() + if type(value) is datetime.date: + return datetime.datetime(value.year, value.month, value.day) + return value + return process + + +class MSDialect_adodbapi(MSDialect): + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = True + driver = 'adodbapi' + + @classmethod + def import_dbapi(cls): + import adodbapi as module + return module + + colspecs = MSDialect.colspecs.copy() + colspecs[sqltypes.DateTime] = MSDateTime_adodbapi + + def create_connect_args(self, url): + keys = url.query + + connectors = ["Provider=SQLOLEDB"] + if 'port' in keys: + connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port"))) + else: + connectors.append ("Data Source=%s" % keys.get("host")) + connectors.append ("Initial Catalog=%s" % keys.get("database")) + user = keys.get("user") + if user: + connectors.append("User Id=%s" % user) + connectors.append("Password=%s" % keys.get("password", "")) + else: + connectors.append("Integrated Security=SSPI") + return [[";".join (connectors)], {}] + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e) + +dialect = MSDialect_adodbapi diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py new file mode 100644 index 000000000..cd031af40 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -0,0 +1,1448 @@ +# mssql.py + +"""Support for the Microsoft SQL Server database. + +Driver +------ + +The MSSQL dialect will work with three different available drivers: + +* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommeded + driver. + +* *pymssql* - http://pymssql.sourceforge.net/ + +* *adodbapi* - http://adodbapi.sourceforge.net/ + +Drivers are loaded in the order listed above based on availability. + +If you need to load a specific driver pass ``module_name`` when +creating the engine:: + + engine = create_engine('mssql+module_name://dsn') + +``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and +``adodbapi``. + +Currently the pyodbc driver offers the greatest level of +compatibility. + +Connecting +---------- + +Connecting with create_engine() uses the standard URL approach of +``mssql://user:pass@host/dbname[?key=value&key=value...]``. + +If the database name is present, the tokens are converted to a +connection string with the specified values. If the database is not +present, then the host token is taken directly as the DSN name. + +Examples of pyodbc connection string URLs: + +* *mssql+pyodbc://mydsn* - connects using the specified DSN named ``mydsn``. + The connection string that is created will appear like:: + + dsn=mydsn;TrustedConnection=Yes + +* *mssql+pyodbc://user:pass@mydsn* - connects using the DSN named + ``mydsn`` passing in the ``UID`` and ``PWD`` information. The + connection string that is created will appear like:: + + dsn=mydsn;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english* - connects + using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD`` + information, plus the additional connection configuration option + ``LANGUAGE``. The connection string that is created will appear + like:: + + dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english + +* *mssql+pyodbc://user:pass@host/db* - connects using a connection string + dynamically created that would appear like:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@host:123/db* - connects using a connection + string that is dynamically created, which also includes the port + information using the comma syntax. If your connection string + requires the port information to be passed as a ``port`` keyword + see the next example. This will create the following connection + string:: + + DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass + +* *mssql+pyodbc://user:pass@host/db?port=123* - connects using a connection + string that is dynamically created that includes the port + information as a separate ``port`` keyword. This will create the + following connection string:: + + DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123 + +If you require a connection string that is outside the options +presented above, use the ``odbc_connect`` keyword to pass in a +urlencoded connection string. What gets passed in will be urldecoded +and passed directly. + +For example:: + + mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb + +would create the following connection string:: + + dsn=mydsn;Database=db + +Encoding your connection string can be easily accomplished through +the python shell. For example:: + + >>> import urllib + >>> urllib.quote_plus('dsn=mydsn;Database=db') + 'dsn%3Dmydsn%3BDatabase%3Ddb' + +Additional arguments which may be specified either as query string +arguments on the URL, or as keyword argument to +:func:`~sqlalchemy.create_engine()` are: + +* *query_timeout* - allows you to override the default query timeout. + Defaults to ``None``. This is only supported on pymssql. + +* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY + should be used in place of the non-scoped version @@IDENTITY. + Defaults to True. + +* *max_identifier_length* - allows you to se the maximum length of + identfiers supported by the database. Defaults to 128. For pymssql + the default is 30. + +* *schema_name* - use to set the schema name. Defaults to ``dbo``. + +Auto Increment Behavior +----------------------- + +``IDENTITY`` columns are supported by using SQLAlchemy +``schema.Sequence()`` objects. In other words:: + + Table('test', mss_engine, + Column('id', Integer, + Sequence('blah',100,10), primary_key=True), + Column('name', String(20)) + ).create() + +would yield:: + + CREATE TABLE test ( + id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, + name VARCHAR(20) NULL, + ) + +Note that the ``start`` and ``increment`` values for sequences are +optional and will default to 1,1. + +* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for + ``INSERT`` s) + +* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on + ``INSERT`` + +Collation Support +----------------- + +MSSQL specific string types support a collation parameter that +creates a column-level specific collation for the column. The +collation parameter accepts a Windows Collation Name or a SQL +Collation Name. Supported types are MSChar, MSNChar, MSString, +MSNVarchar, MSText, and MSNText. For example:: + + Column('login', String(32, collation='Latin1_General_CI_AS')) + +will yield:: + + login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL + +LIMIT/OFFSET Support +-------------------- + +MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is +supported directly through the ``TOP`` Transact SQL keyword:: + + select.limit + +will yield:: + + SELECT TOP n + +If using SQL Server 2005 or above, LIMIT with OFFSET +support is available through the ``ROW_NUMBER OVER`` construct. +For versions below 2005, LIMIT with OFFSET usage will fail. + +Nullability +----------- +MSSQL has support for three levels of column nullability. The default +nullability allows nulls and is explicit in the CREATE TABLE +construct:: + + name VARCHAR(20) NULL + +If ``nullable=None`` is specified then no specification is made. In +other words the database's configured default is used. This will +render:: + + name VARCHAR(20) + +If ``nullable`` is ``True`` or ``False`` then the column will be +``NULL` or ``NOT NULL`` respectively. + +Date / Time Handling +-------------------- +DATE and TIME are supported. Bind parameters are converted +to datetime.datetime() objects as required by most MSSQL drivers, +and results are processed from strings if needed. +The DATE and TIME types are not available for MSSQL 2005 and +previous - if a server version below 2008 is detected, DDL +for these types will be issued as DATETIME. + +Compatibility Levels +-------------------- +MSSQL supports the notion of setting compatibility levels at the +database level. This allows, for instance, to run a database that +is compatibile with SQL2000 while running on a SQL2005 database +server. ``server_version_info`` will always retrun the database +server version information (in this case SQL2005) and not the +compatibiility level information. Because of this, if running under +a backwards compatibility mode SQAlchemy may attempt to use T-SQL +statements that are unable to be parsed by the database server. + +Known Issues +------------ + +* No support for more than one ``IDENTITY`` column per table + +* pymssql has problems with binary and unicode data that this module + does **not** work around + +""" +import datetime, decimal, inspect, operator, sys, re +import itertools + +from sqlalchemy import sql, schema as sa_schema, exc, util +from sqlalchemy.sql import select, compiler, expression, \ + operators as sql_operators, \ + functions as sql_functions, util as sql_util +from sqlalchemy.engine import default, base, reflection +from sqlalchemy import types as sqltypes +from decimal import Decimal as _python_Decimal +from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ + FLOAT, TIMESTAMP, DATETIME, DATE + + +from sqlalchemy.dialects.mssql import information_schema as ischema + +MS_2008_VERSION = (10,) +MS_2005_VERSION = (9,) +MS_2000_VERSION = (8,) + +RESERVED_WORDS = set( + ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', + 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', + 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', + 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', + 'containstable', 'continue', 'convert', 'create', 'cross', 'current', + 'current_date', 'current_time', 'current_timestamp', 'current_user', + 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', + 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', + 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', + 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', + 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', + 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', + 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', + 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', + 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', + 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', + 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', + 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', + 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', + 'reconfigure', 'references', 'replication', 'restore', 'restrict', + 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', + 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', + 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', + 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', + 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', + 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', + 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', + 'writetext', + ]) + + +class _MSNumeric(sqltypes.Numeric): + def result_processor(self, dialect): + if self.asdecimal: + def process(value): + if value is not None: + return _python_Decimal(str(value)) + else: + return value + return process + else: + def process(value): + return float(value) + return process + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, decimal.Decimal): + if value.adjusted() < 0: + result = "%s0.%s%s" % ( + (value < 0 and '-' or ''), + '0' * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value._int])) + + else: + if 'E' in str(value): + result = "%s%s%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int]), + "0" * (value.adjusted() - (len(value._int)-1))) + else: + if (len(value._int) - 1) > value.adjusted(): + result = "%s%s.%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int][0:value.adjusted() + 1]), + "".join([str(s) for s in value._int][value.adjusted() + 1:])) + else: + result = "%s%s" % ( + (value < 0 and '-' or ''), + "".join([str(s) for s in value._int][0:value.adjusted() + 1])) + + return result + + else: + return value + + return process + +class REAL(sqltypes.Float): + """A type for ``real`` numbers.""" + + __visit_name__ = 'REAL' + + def __init__(self): + super(REAL, self).__init__(precision=24) + +class TINYINT(sqltypes.Integer): + __visit_name__ = 'TINYINT' + + +# MSSQL DATE/TIME types have varied behavior, sometimes returning +# strings. MSDate/TIME check for everything, and always +# filter bind parameters into datetime objects (required by pyodbc, +# not sure about other dialects). + +class _MSDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.date() + elif isinstance(value, basestring): + return datetime.date(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process + +class TIME(sqltypes.TIME): + def __init__(self, precision=None, **kwargs): + self.precision = precision + super(TIME, self).__init__() + + __zero_date = datetime.date(1900, 1, 1) + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine(self.__zero_date, value.time()) + elif isinstance(value, datetime.time): + value = datetime.datetime.combine(self.__zero_date, value) + return value + return process + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") + def result_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, basestring): + return datetime.time(*[int(x or 0) for x in self._reg.match(value).groups()]) + else: + return value + return process + + +class _DateTimeBase(object): + def bind_processor(self, dialect): + def process(value): + # TODO: why ? + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + return process + +class _MSDateTime(_DateTimeBase, sqltypes.DateTime): + pass + +class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'SMALLDATETIME' + +class DATETIME2(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = 'DATETIME2' + + def __init__(self, precision=None, **kwargs): + self.precision = precision + + +# TODO: is this not an Interval ? +class DATETIMEOFFSET(sqltypes.TypeEngine): + __visit_name__ = 'DATETIMEOFFSET' + + def __init__(self, precision=None, **kwargs): + self.precision = precision + + +class _StringType(object): + """Base for MSSQL string types.""" + + def __init__(self, collation=None): + self.collation = collation + + def __repr__(self): + attributes = inspect.getargspec(self.__init__)[0][1:] + attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) + + params = {} + for attr in attributes: + val = getattr(self, attr) + if val is not None and val is not False: + params[attr] = val + + return "%s(%s)" % (self.__class__.__name__, + ', '.join(['%s=%r' % (k, params[k]) for k in params])) + + +class TEXT(_StringType, sqltypes.TEXT): + """MSSQL TEXT type, for variable-length text up to 2^31 characters.""" + + def __init__(self, *args, **kw): + """Construct a TEXT. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.Text.__init__(self, *args, **kw) + +class NTEXT(_StringType, sqltypes.UnicodeText): + """MSSQL NTEXT type, for variable-length unicode text up to 2^30 + characters.""" + + __visit_name__ = 'NTEXT' + + def __init__(self, *args, **kwargs): + """Construct a NTEXT. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kwargs.pop('collation', None) + _StringType.__init__(self, collation) + length = kwargs.pop('length', None) + sqltypes.UnicodeText.__init__(self, length, **kwargs) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum + of 8,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a VARCHAR. + + :param length: Optinal, maximum data length, in characters. + + :param convert_unicode: defaults to False. If True, convert + ``unicode`` data sent to the database to a ``str`` + bytestring, and convert bytestrings coming back from the + database into ``unicode``. + + Bytestrings are encoded using the dialect's + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which + defaults to `utf-8`. + + If False, may be overridden by + :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. + + :param assert_unicode: + + If None (the default), no assertion will take place unless + overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. + + If 'warn', will issue a runtime warning if a ``str`` + instance is used as a bind value. + + If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.VARCHAR.__init__(self, *args, **kw) + +class NVARCHAR(_StringType, sqltypes.NVARCHAR): + """MSSQL NVARCHAR type. + + For variable-length unicode character data up to 4,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a NVARCHAR. + + :param length: Optional, Maximum data length, in characters. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NVARCHAR.__init__(self, *args, **kw) + +class CHAR(_StringType, sqltypes.CHAR): + """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum + of 8,000 characters.""" + + def __init__(self, *args, **kw): + """Construct a CHAR. + + :param length: Optinal, maximum data length, in characters. + + :param convert_unicode: defaults to False. If True, convert + ``unicode`` data sent to the database to a ``str`` + bytestring, and convert bytestrings coming back from the + database into ``unicode``. + + Bytestrings are encoded using the dialect's + :attr:`~sqlalchemy.engine.base.Dialect.encoding`, which + defaults to `utf-8`. + + If False, may be overridden by + :attr:`sqlalchemy.engine.base.Dialect.convert_unicode`. + + :param assert_unicode: + + If None (the default), no assertion will take place unless + overridden by :attr:`sqlalchemy.engine.base.Dialect.assert_unicode`. + + If 'warn', will issue a runtime warning if a ``str`` + instance is used as a bind value. + + If true, will raise an :exc:`sqlalchemy.exc.InvalidRequestError`. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.CHAR.__init__(self, *args, **kw) + +class NCHAR(_StringType, sqltypes.NCHAR): + """MSSQL NCHAR type. + + For fixed-length unicode character data up to 4,000 characters.""" + + def __init__(self, *args, **kw): + """Construct an NCHAR. + + :param length: Optional, Maximum data length, in characters. + + :param collation: Optional, a column-level collation for this string + value. Accepts a Windows Collation Name or a SQL Collation Name. + + """ + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NCHAR.__init__(self, *args, **kw) + +class BINARY(sqltypes.Binary): + __visit_name__ = 'BINARY' + +class VARBINARY(sqltypes.Binary): + __visit_name__ = 'VARBINARY' + +class IMAGE(sqltypes.Binary): + __visit_name__ = 'IMAGE' + +class BIT(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + +class _MSBoolean(sqltypes.Boolean): + def result_processor(self, dialect): + def process(value): + if value is None: + return None + return value and True or False + return process + + def bind_processor(self, dialect): + def process(value): + if value is True: + return 1 + elif value is False: + return 0 + elif value is None: + return None + else: + return value and True or False + return process + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = 'MONEY' + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = 'SMALLMONEY' + +class UNIQUEIDENTIFIER(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + +class SQL_VARIANT(sqltypes.TypeEngine): + __visit_name__ = 'SQL_VARIANT' + +# old names. +MSNumeric = _MSNumeric +MSDateTime = _MSDateTime +MSDate = _MSDate +MSBoolean = _MSBoolean +MSReal = REAL +MSTinyInteger = TINYINT +MSTime = TIME +MSSmallDateTime = SMALLDATETIME +MSDateTime2 = DATETIME2 +MSDateTimeOffset = DATETIMEOFFSET +MSText = TEXT +MSNText = NTEXT +MSString = VARCHAR +MSNVarchar = NVARCHAR +MSChar = CHAR +MSNChar = NCHAR +MSBinary = BINARY +MSVarBinary = VARBINARY +MSImage = IMAGE +MSBit = BIT +MSMoney = MONEY +MSSmallMoney = SMALLMONEY +MSUniqueIdentifier = UNIQUEIDENTIFIER +MSVariant = SQL_VARIANT + +colspecs = { + sqltypes.Numeric : _MSNumeric, + sqltypes.DateTime : _MSDateTime, + sqltypes.Date : _MSDate, + sqltypes.Time : TIME, + sqltypes.Boolean : _MSBoolean, +} + +ischema_names = { + 'int' : INTEGER, + 'bigint': BIGINT, + 'smallint' : SMALLINT, + 'tinyint' : TINYINT, + 'varchar' : VARCHAR, + 'nvarchar' : NVARCHAR, + 'char' : CHAR, + 'nchar' : NCHAR, + 'text' : TEXT, + 'ntext' : NTEXT, + 'decimal' : DECIMAL, + 'numeric' : NUMERIC, + 'float' : FLOAT, + 'datetime' : DATETIME, + 'datetime2' : DATETIME2, + 'datetimeoffset' : DATETIMEOFFSET, + 'date': DATE, + 'time': TIME, + 'smalldatetime' : SMALLDATETIME, + 'binary' : BINARY, + 'varbinary' : VARBINARY, + 'bit': BIT, + 'real' : REAL, + 'image' : IMAGE, + 'timestamp': TIMESTAMP, + 'money': MONEY, + 'smallmoney': SMALLMONEY, + 'uniqueidentifier': UNIQUEIDENTIFIER, + 'sql_variant': SQL_VARIANT, +} + + +class MSTypeCompiler(compiler.GenericTypeCompiler): + def _extend(self, spec, type_): + """Extend a string-type declaration with standard SQL + COLLATE annotations. + + """ + + if getattr(type_, 'collation', None): + collation = 'COLLATE %s' % type_.collation + else: + collation = None + + if type_.length: + spec = spec + "(%d)" % type_.length + + return ' '.join([c for c in (spec, collation) + if c is not None]) + + def visit_FLOAT(self, type_): + precision = getattr(type_, 'precision', None) + if precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': precision} + + def visit_REAL(self, type_): + return "REAL" + + def visit_TINYINT(self, type_): + return "TINYINT" + + def visit_DATETIMEOFFSET(self, type_): + if type_.precision: + return "DATETIMEOFFSET(%s)" % type_.precision + else: + return "DATETIMEOFFSET" + + def visit_TIME(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "TIME(%s)" % precision + else: + return "TIME" + + def visit_DATETIME2(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "DATETIME2(%s)" % precision + else: + return "DATETIME2" + + def visit_SMALLDATETIME(self, type_): + return "SMALLDATETIME" + + def visit_unicode(self, type_): + return self.visit_NVARCHAR(type_) + + def visit_unicode_text(self, type_): + return self.visit_NTEXT(type_) + + def visit_NTEXT(self, type_): + return self._extend("NTEXT", type_) + + def visit_TEXT(self, type_): + return self._extend("TEXT", type_) + + def visit_VARCHAR(self, type_): + return self._extend("VARCHAR", type_) + + def visit_CHAR(self, type_): + return self._extend("CHAR", type_) + + def visit_NCHAR(self, type_): + return self._extend("NCHAR", type_) + + def visit_NVARCHAR(self, type_): + return self._extend("NVARCHAR", type_) + + def visit_date(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return self.visit_DATE(type_) + + def visit_time(self, type_): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_) + else: + return self.visit_TIME(type_) + + def visit_binary(self, type_): + if type_.length: + return self.visit_BINARY(type_) + else: + return self.visit_IMAGE(type_) + + def visit_BINARY(self, type_): + if type_.length: + return "BINARY(%s)" % type_.length + else: + return "BINARY" + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_VARBINARY(self, type_): + if type_.length: + return "VARBINARY(%s)" % type_.length + else: + return "VARBINARY" + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return 'SMALLMONEY' + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + + def visit_SQL_VARIANT(self, type_): + return 'SQL_VARIANT' + +class MSExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + _select_lastrowid = False + _result_proxy = None + _lastrowid = None + + def pre_exec(self): + """Activate IDENTITY_INSERT if needed.""" + + if self.isinsert: + tbl = self.compiled.statement.table + seq_column = tbl._autoincrement_column + insert_has_sequence = seq_column is not None + + if insert_has_sequence: + self._enable_identity_insert = seq_column.key in self.compiled_parameters[0] + else: + self._enable_identity_insert = False + + self._select_lastrowid = insert_has_sequence and \ + not self.compiled.returning and \ + not self._enable_identity_insert and \ + not self.executemany + + if self._enable_identity_insert: + self.cursor.execute("SET IDENTITY_INSERT %s ON" % + self.dialect.identifier_preparer.format_table(tbl)) + + def post_exec(self): + """Disable IDENTITY_INSERT if enabled.""" + + if self._select_lastrowid: + if self.dialect.use_scope_identity: + self.cursor.execute("SELECT scope_identity() AS lastrowid") + else: + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchall()[0] # fetchall() ensures the cursor is consumed without closing it + self._lastrowid = int(row[0]) + + if (self.isinsert or self.isupdate or self.isdelete) and self.compiled.returning: + self._result_proxy = base.FullyBufferedResultProxy(self) + + if self._enable_identity_insert: + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) + + def get_lastrowid(self): + return self._lastrowid + + def handle_dbapi_exception(self, e): + if self._enable_identity_insert: + try: + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) + except: + pass + + def get_result_proxy(self): + if self._result_proxy: + return self._result_proxy + else: + return base.ResultProxy(self) + +class MSSQLCompiler(compiler.SQLCompiler): + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update ({ + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond', + 'microseconds': 'microsecond' + }) + + def __init__(self, *args, **kwargs): + super(MSSQLCompiler, self).__init__(*args, **kwargs) + self.tablealiases = {} + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_current_date_func(self, fn, **kw): + return "GETDATE()" + + def visit_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_char_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_concat_op(self, binary): + return "%s + %s" % (self.process(binary.left), self.process(binary.right)) + + def visit_match_op(self, binary): + return "CONTAINS (%s, %s)" % (self.process(binary.left), self.process(binary.right)) + + def get_select_precolumns(self, select): + """ MS-SQL puts TOP, it's version of LIMIT here """ + if select._distinct or select._limit: + s = select._distinct and "DISTINCT " or "" + + if select._limit: + if not select._offset: + s += "TOP %s " % (select._limit,) + return s + return compiler.SQLCompiler.get_select_precolumns(self, select) + + def limit_clause(self, select): + # Limit in mssql is after the select keyword + return "" + + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``row_number()`` criterion. + + """ + if not getattr(select, '_mssql_visit', None) and select._offset: + # to use ROW_NUMBER(), an ORDER BY is required. + orderby = self.process(select._order_by_clause) + if not orderby: + raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.') + + _offset = select._offset + _limit = select._limit + select._mssql_visit = True + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias() + + limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) + limitselect.append_whereclause("mssql_rn>%d" % _offset) + if _limit is not None: + limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) + return self.process(limitselect, iswrapper=True, **kwargs) + else: + return compiler.SQLCompiler.visit_select(self, select, **kwargs) + + def _schema_aliased_table(self, table): + if getattr(table, 'schema', None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_table(self, table, mssql_aliased=False, **kwargs): + if mssql_aliased: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + # alias schema-qualified tables + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=True, **kwargs) + else: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + def visit_alias(self, alias, **kwargs): + # translate for schema-qualified table aliases + self.tablealiases[alias.original] = alias + kwargs['mssql_aliased'] = True + return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) + + def visit_extract(self, extract): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_column(self, column, result_map=None, **kwargs): + if column.table is not None and \ + (not self.isupdate and not self.isdelete) or self.is_subquery(): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = expression._corresponding_column_or_error(t, column) + + if result_map is not None: + result_map[column.name.lower()] = (column.name, (column, ), column.type) + + return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs) + + return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs) + + def visit_binary(self, binary, **kwargs): + """Move bind parameters to the right-hand side of an operator, where + possible. + + """ + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \ + and not isinstance(binary.right, expression._BindParamClause): + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs) + else: + if (binary.operator is operator.eq or binary.operator is operator.ne) and ( + (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \ + (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \ + isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)): + op = binary.operator == operator.eq and "IN" or "NOT IN" + return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) + return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) + + def returning_clause(self, stmt, returning_cols): + + if self.isinsert or self.isupdate: + target = stmt.table.alias("inserted") + else: + target = stmt.table.alias("deleted") + + adapter = sql_util.ClauseAdapter(target) + def col_label(col): + adapted = adapter.traverse(c) + if isinstance(c, expression._Label): + return adapted.label(c.key) + else: + return self.label_select_column(None, adapted, asfrom=False) + + columns = [ + self.process( + col_label(c), + within_columns_clause=True, + result_map=self.result_map + ) + for c in expression._select_iterables(returning_cols) + ] + return 'OUTPUT ' + ', '.join(columns) + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super(MSSQLCompiler, self).label_select_column(select, column, asfrom) + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # MSSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + +class MSDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type) + + if column.nullable is not None: + if not column.nullable or column.primary_key: + colspec += " NOT NULL" + else: + colspec += " NULL" + + if not column.table: + raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL") + + seq_col = column.table._autoincrement_column + + # install a IDENTITY Sequence if we have an implicit IDENTITY column + if seq_col is column: + sequence = getattr(column, 'sequence', None) + if sequence: + start, increment = sequence.start or 1, sequence.increment or 1 + else: + start, increment = 1, 1 + colspec += " IDENTITY(%s,%s)" % (start, increment) + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_drop_index(self, drop): + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(drop.element.table.name), + self.preparer.quote(self._validate_identifier(drop.element.name, False), drop.element.quote) + ) + + +class MSIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') + + def _escape_identifier(self, value): + #TODO: determine MSSQL's escaping rules + return value + + def quote_schema(self, schema, force=True): + """Prepare a quoted table and schema name.""" + result = '.'.join([self.quote(x, force) for x in schema.split('.')]) + return result + +class MSDialect(default.DefaultDialect): + name = 'mssql' + supports_default_values = True + supports_empty_insert = False + execution_ctx_cls = MSExecutionContext + text_as_varchar = False + use_scope_identity = True + max_identifier_length = 128 + schema_name = "dbo" + colspecs = colspecs + ischema_names = ischema_names + + supports_unicode_binds = True + postfetch_lastrowid = True + + server_version_info = () + + statement_compiler = MSSQLCompiler + ddl_compiler = MSDDLCompiler + type_compiler = MSTypeCompiler + preparer = MSIdentifierPreparer + + def __init__(self, + query_timeout=None, + use_scope_identity=True, + max_identifier_length=None, + schema_name="dbo", **opts): + self.query_timeout = int(query_timeout or 0) + self.schema_name = schema_name + + self.use_scope_identity = use_scope_identity + self.max_identifier_length = int(max_identifier_length or 0) or \ + self.max_identifier_length + super(MSDialect, self).__init__(**opts) + + def do_savepoint(self, connection, name): + util.warn("Savepoint support in mssql is experimental and may lead to data loss.") + connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") + connection.execute("SAVE TRANSACTION %s" % name) + + def do_release_savepoint(self, connection, name): + pass + + def initialize(self, connection): + super(MSDialect, self).initialize(connection) + if self.server_version_info >= MS_2005_VERSION and 'implicit_returning' not in self.__dict__: + self.implicit_returning = True + + def get_default_schema_name(self, connection): + return self.default_schema_name + + def _get_default_schema_name(self, connection): + user_name = connection.scalar("SELECT user_name() as user_name;") + if user_name is not None: + # now, get the default schema + query = """ + SELECT default_schema_name FROM + sys.database_principals + WHERE name = ? + AND type = 'S' + """ + try: + default_schema_name = connection.scalar(query, [user_name]) + if default_schema_name is not None: + return default_schema_name + except: + pass + return self.schema_name + + def table_names(self, connection, schema): + s = select([ischema.tables.c.table_name], ischema.tables.c.table_schema==schema) + return [row[0] for row in connection.execute(s)] + + + def has_table(self, connection, tablename, schema=None): + current_schema = schema or self.default_schema_name + columns = ischema.columns + s = sql.select([columns], + current_schema + and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) + or columns.c.table_name==tablename, + ) + + c = connection.execute(s) + row = c.fetchone() + return row is not None + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = sql.select([ischema.schemata.c.schema_name], + order_by=[ischema.schemata.c.schema_name] + ) + schema_names = [r[0] for r in connection.execute(s)] + return schema_names + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + current_schema = schema or self.default_schema_name + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == current_schema, + tables.c.table_type == 'BASE TABLE' + ), + order_by=[tables.c.table_name] + ) + table_names = [r[0] for r in connection.execute(s)] + return table_names + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + current_schema = schema or self.default_schema_name + tables = ischema.tables + s = sql.select([tables.c.table_name], + sql.and_( + tables.c.table_schema == current_schema, + tables.c.table_type == 'VIEW' + ), + order_by=[tables.c.table_name] + ) + view_names = [r[0] for r in connection.execute(s)] + return view_names + + # The cursor reports it is closed after executing the sp. + @reflection.cache + def get_indexes(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + full_tname = "%s.%s" % (current_schema, tablename) + indexes = [] + s = sql.text("exec sp_helpindex '%s'" % full_tname) + rp = connection.execute(s) + if rp.closed: + # did not work for this setup. + return [] + for row in rp: + if 'primary key' not in row['index_description']: + indexes.append({ + 'name' : row['index_name'], + 'column_names' : [c.strip() for c in row['index_keys'].split(',')], + 'unique': 'unique' in row['index_description'] + }) + return indexes + + @reflection.cache + def get_view_definition(self, connection, viewname, schema=None, **kw): + current_schema = schema or self.default_schema_name + views = ischema.views + s = sql.select([views.c.view_definition], + sql.and_( + views.c.table_schema == current_schema, + views.c.table_name == viewname + ), + ) + rp = connection.execute(s) + if rp: + view_def = rp.scalar() + return view_def + + @reflection.cache + def get_columns(self, connection, tablename, schema=None, **kw): + # Get base columns + current_schema = schema or self.default_schema_name + columns = ischema.columns + s = sql.select([columns], + current_schema + and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) + or columns.c.table_name==tablename, + order_by=[columns.c.ordinal_position]) + c = connection.execute(s) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + (name, type, nullable, charlen, numericprec, numericscale, default, collation) = ( + row[columns.c.column_name], + row[columns.c.data_type], + row[columns.c.is_nullable] == 'YES', + row[columns.c.character_maximum_length], + row[columns.c.numeric_precision], + row[columns.c.numeric_scale], + row[columns.c.column_default], + row[columns.c.collation_name] + ) + coltype = self.ischema_names.get(type, None) + + kwargs = {} + if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, MSNText, MSBinary, MSVarBinary, sqltypes.Binary): + kwargs['length'] = charlen + if collation: + kwargs['collation'] = collation + if coltype == MSText or (coltype in (MSString, MSNVarchar) and charlen == -1): + kwargs.pop('length') + + if coltype is None: + util.warn("Did not recognize type '%s' of column '%s'" % (type, name)) + coltype = sqltypes.NULLTYPE + + if issubclass(coltype, sqltypes.Numeric) and coltype is not MSReal: + kwargs['scale'] = numericscale + kwargs['precision'] = numericprec + + coltype = coltype(**kwargs) + cdict = { + 'name' : name, + 'type' : coltype, + 'nullable' : nullable, + 'default' : default, + 'autoincrement':False, + } + cols.append(cdict) + # autoincrement and identity + colmap = {} + for col in cols: + colmap[col['name']] = col + # We also run an sp_columns to check for identity columns: + cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (tablename, current_schema)) + ic = None + while True: + row = cursor.fetchone() + if row is None: + break + (col_name, type_name) = row[3], row[5] + if type_name.endswith("identity") and col_name in colmap: + ic = col_name + colmap[col_name]['autoincrement'] = True + colmap[col_name]['sequence'] = dict( + name='%s_identity' % col_name) + break + cursor.close() + if ic is not None: + try: + # is this table_fullname reliable? + table_fullname = "%s.%s" % (current_schema, tablename) + cursor = connection.execute( + sql.text("select ident_seed(:seed), ident_incr(:incr)"), + {'seed':table_fullname, 'incr':table_fullname} + ) + row = cursor.fetchone() + cursor.close() + if not row is None: + colmap[ic]['sequence'].update({ + 'start' : int(row[0]), + 'increment' : int(row[1]) + }) + except: + # ignoring it, works just like before + pass + return cols + + @reflection.cache + def get_primary_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + pkeys = [] + # Add constraints + RR = ischema.ref_constraints #information_schema.referential_constraints + TC = ischema.constraints #information_schema.table_constraints + C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column + R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column + + # Primary key constraints + s = sql.select([C.c.column_name, TC.c.constraint_type], + sql.and_(TC.c.constraint_name == C.c.constraint_name, + C.c.table_name == tablename, + C.c.table_schema == current_schema) + ) + c = connection.execute(s) + for row in c: + if 'PRIMARY' in row[TC.c.constraint_type.name]: + pkeys.append(row[0]) + return pkeys + + @reflection.cache + def get_foreign_keys(self, connection, tablename, schema=None, **kw): + current_schema = schema or self.default_schema_name + # Add constraints + RR = ischema.ref_constraints #information_schema.referential_constraints + TC = ischema.constraints #information_schema.table_constraints + C = ischema.key_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column + R = ischema.key_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column + + # Foreign key constraints + s = sql.select([C.c.column_name, + R.c.table_schema, R.c.table_name, R.c.column_name, + RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule], + sql.and_(C.c.table_name == tablename, + C.c.table_schema == current_schema, + C.c.constraint_name == RR.c.constraint_name, + R.c.constraint_name == RR.c.unique_constraint_name, + C.c.ordinal_position == R.c.ordinal_position + ), + order_by = [RR.c.constraint_name, R.c.ordinal_position]) + + + # group rows by constraint ID, to handle multi-column FKs + fkeys = [] + fknm, scols, rcols = (None, [], []) + + def fkey_rec(): + return { + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + } + + fkeys = util.defaultdict(fkey_rec) + + for r in connection.execute(s).fetchall(): + scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r + + rec = fkeys[rfknm] + rec['name'] = rfknm + if not rec['referred_table']: + rec['referred_table'] = rtbl + + if schema is not None or current_schema != rschema: + rec['referred_schema'] = rschema + + local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + + local_cols.append(scol) + remote_cols.append(rcol) + + return fkeys.values() + + +# fixme. I added this for the tests to run. -Randall +MSSQLDialect = MSDialect diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py new file mode 100644 index 000000000..bb6ff315a --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -0,0 +1,83 @@ +from sqlalchemy import Table, MetaData, Column, ForeignKey +from sqlalchemy.types import String, Unicode, Integer, TypeDecorator + +ischema = MetaData() + +class CoerceUnicode(TypeDecorator): + impl = Unicode + + def process_bind_param(self, value, dialect): + if isinstance(value, str): + value = value.decode(dialect.encoding) + return value + +schemata = Table("SCHEMATA", ischema, + Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), + Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), + Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), + schema="INFORMATION_SCHEMA") + +tables = Table("TABLES", ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("TABLE_TYPE", String, key="table_type"), + schema="INFORMATION_SCHEMA") + +columns = Table("COLUMNS", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column("CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="INFORMATION_SCHEMA") + +constraints = Table("TABLE_CONSTRAINTS", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_TYPE", String, key="constraint_type"), + schema="INFORMATION_SCHEMA") + +column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + schema="INFORMATION_SCHEMA") + +key_constraints = Table("KEY_COLUMN_USAGE", ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + schema="INFORMATION_SCHEMA") + +ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, + Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, key="unique_constraint_catalog"), # TODO: is CATLOG misspelled ? + Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, key="unique_constraint_schema"), + Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"), + Column("MATCH_OPTION", String, key="match_option"), + Column("UPDATE_RULE", String, key="update_rule"), + Column("DELETE_RULE", String, key="delete_rule"), + schema="INFORMATION_SCHEMA") + +views = Table("VIEWS", ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), + Column("CHECK_OPTION", String, key="check_option"), + Column("IS_UPDATABLE", String, key="is_updatable"), + schema="INFORMATION_SCHEMA") + diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py new file mode 100644 index 000000000..0961c2e76 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -0,0 +1,46 @@ +from sqlalchemy.dialects.mssql.base import MSDialect +from sqlalchemy import types as sqltypes + + +class MSDialect_pymssql(MSDialect): + supports_sane_rowcount = False + max_identifier_length = 30 + driver = 'pymssql' + + @classmethod + def import_dbapi(cls): + import pymssql as module + # pymmsql doesn't have a Binary method. we use string + # TODO: monkeypatching here is less than ideal + module.Binary = lambda st: str(st) + return module + + def __init__(self, **params): + super(MSSQLDialect_pymssql, self).__init__(**params) + self.use_scope_identity = True + + # pymssql understands only ascii + if self.convert_unicode: + util.warn("pymssql does not support unicode") + self.encoding = params.get('encoding', 'ascii') + + + def create_connect_args(self, url): + if hasattr(self, 'query_timeout'): + # ick, globals ? we might want to move this.... + self.dbapi._mssql.set_query_timeout(self.query_timeout) + + keys = url.query + if keys.get('port'): + # pymssql expects port as host:port, not a separate arg + keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])]) + del keys['port'] + return [[], keys] + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e) + + def do_begin(self, connection): + pass + +dialect = MSDialect_pymssql
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py new file mode 100644 index 000000000..9a2a9e4e7 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -0,0 +1,79 @@ +from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect +from sqlalchemy.connectors.pyodbc import PyODBCConnector +from sqlalchemy import types as sqltypes +import re +import sys + +class MSExecutionContext_pyodbc(MSExecutionContext): + _embedded_scope_identity = False + + def pre_exec(self): + """where appropriate, issue "select scope_identity()" in the same statement. + + Background on why "scope_identity()" is preferable to "@@identity": + http://msdn.microsoft.com/en-us/library/ms190315.aspx + + Background on why we attempt to embed "scope_identity()" into the same + statement as the INSERT: + http://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values? + + """ + + super(MSExecutionContext_pyodbc, self).pre_exec() + + # don't embed the scope_identity select into an "INSERT .. DEFAULT VALUES" + if self._select_lastrowid and \ + self.dialect.use_scope_identity and \ + len(self.parameters[0]): + self._embedded_scope_identity = True + + self.statement += "; select scope_identity()" + + def post_exec(self): + if self._embedded_scope_identity: + # Fetch the last inserted id from the manipulated statement + # We may have to skip over a number of result sets with no data (due to triggers, etc.) + while True: + try: + # fetchall() ensures the cursor is consumed without closing it (FreeTDS particularly) + row = self.cursor.fetchall()[0] + break + except self.dialect.dbapi.Error, e: + # no way around this - nextset() consumes the previous set + # so we need to just keep flipping + self.cursor.nextset() + + self._lastrowid = int(row[0]) + else: + super(MSExecutionContext_pyodbc, self).post_exec() + + +class MSDialect_pyodbc(PyODBCConnector, MSDialect): + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + execution_ctx_cls = MSExecutionContext_pyodbc + + pyodbc_driver_name = 'SQL Server' + + def __init__(self, description_encoding='latin-1', **params): + super(MSDialect_pyodbc, self).__init__(**params) + self.description_encoding = description_encoding + self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset') + + def initialize(self, connection): + super(MSDialect_pyodbc, self).initialize(connection) + pyodbc = self.dbapi + + dbapi_con = connection.connection + + self._free_tds = re.match(r".*libtdsodbc.*\.so", dbapi_con.getinfo(pyodbc.SQL_DRIVER_NAME)) + + # the "Py2K only" part here is theoretical. + # have not tried pyodbc + python3.1 yet. + # Py2K + self.supports_unicode_statements = not self._free_tds + self.supports_unicode_binds = not self._free_tds + # end Py2K + +dialect = MSDialect_pyodbc |
