diff options
| -rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 48 | ||||
| -rwxr-xr-x | test/dialect/mssql.py | 16 |
2 files changed, 56 insertions, 8 deletions
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 42743870a..1ff482cf5 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -443,6 +443,26 @@ class MSSQLDialect(default.DefaultDialect): raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi') dbapi = classmethod(dbapi) + def server_version_info(self, connection): + """A tuple of the database server version. + + Formats the remote server version as a tuple of version values, + e.g. ``(9, 0, 1399)``. If there are strings in the version number + they will be in the tuple too, so don't count on these all being + ``int`` values. + + This is a fast check that does not require a round trip. It is also + cached per-Connection. + """ + return connection.dialect._server_version_info(connection.connection) + server_version_info = base.connection_memoize( + ('mssql', 'server_version_info'))(server_version_info) + + def _server_version_info(self, dbapi_con): + """Return a tuple of the database's version number.""" + + raise NotImplementedError() + def create_connect_args(self, url): opts = url.translate_connect_args(username='user') opts.update(url.query) @@ -772,18 +792,18 @@ class MSSQLDialect_pyodbc(MSSQLDialect): if 'max_identifier_length' in keys: self.max_identifier_length = int(keys.pop('max_identifier_length')) if 'dsn' in keys: - connectors = ['dsn=%s' % keys['dsn']] + connectors = ['dsn=%s' % keys.pop('dsn')] else: connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'), - 'Server=%s' % keys['host'], - 'Database=%s' % keys['database'] ] + 'Server=%s' % keys.pop('host', ''), + 'Database=%s' % keys.pop('database', '') ] if 'port' in keys: - connectors.append('Port=%d' % int(keys['port'])) + connectors.append('Port=%d' % int(keys.pop('port'))) - user = keys.get("user") + user = keys.pop("user", None) if user: connectors.append("UID=%s" % user) - connectors.append("PWD=%s" % keys.get("password", "")) + connectors.append("PWD=%s" % keys.pop('password', '')) else: connectors.append("TrustedConnection=Yes") @@ -791,7 +811,7 @@ class MSSQLDialect_pyodbc(MSSQLDialect): # textual data from your database encoding to your client encoding # This should obviously be set to 'No' if you query a cp1253 encoded # database from a latin1 client... - if 'odbc_autotranslate' in keys: + if 'odbc_autotranslate' in keys: connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate")) # Allow specification of partial ODBC connect string @@ -800,7 +820,7 @@ class MSSQLDialect_pyodbc(MSSQLDialect): if odbc_options[0]=="'" and odbc_options[-1]=="'": odbc_options=odbc_options[1:-1] connectors.append(odbc_options) - + connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()]) return [[";".join (connectors)], {}] def is_disconnect(self, e): @@ -828,6 +848,18 @@ class MSSQLDialect_pyodbc(MSSQLDialect): cursor.nextset() context._last_inserted_ids = [int(row[0])] + def _server_version_info(self, dbapi_con): + """Convert a pyodbc SQL_DBMS_VER string into a tuple.""" + + version = [] + r = re.compile('[.\-]') + for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): + try: + version.append(int(n)) + except ValueError: + version.append(n) + return tuple(version) + class MSSQLDialect_adodbapi(MSSQLDialect): supports_sane_rowcount = True supports_sane_multi_rowcount = True diff --git a/test/dialect/mssql.py b/test/dialect/mssql.py index 02c583d5d..4708cc28c 100755 --- a/test/dialect/mssql.py +++ b/test/dialect/mssql.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import * from sqlalchemy import exc from sqlalchemy.sql import table, column from sqlalchemy.databases import mssql +import sqlalchemy.engine.url as url from testlib import * @@ -362,5 +363,20 @@ class MatchTest(TestBase, AssertsCompiledSQL): self.assertEquals([1, 3, 5], [r.id for r in results]) +class ParseConnectTest(TestBase, AssertsCompiledSQL): + __only_on__ = 'mssql' + + def test_pyodbc_connect(self): + u = url.make_url('mssql://username:password@hostspec/database') + dialect = mssql.MSSQLDialect_pyodbc() + connection = dialect.create_connect_args(u) + self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) + + def test_pyodbc_extra_connect(self): + u = url.make_url('mssql://username:password@hostspec/database?LANGUAGE=us_english&foo=bar') + dialect = mssql.MSSQLDialect_pyodbc() + connection = dialect.create_connect_args(u) + self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection) + if __name__ == "__main__": testenv.main() |
