diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-03-14 22:04:20 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2010-03-14 22:04:20 +0000 |
commit | 39fd3442e306f9c2981c347ab2487921f3948a61 (patch) | |
tree | 50868207def3fda8434be61660fae8944dde1229 /lib/sqlalchemy/dialects/sybase/base.py | |
parent | d9af1828fbd79cc925abce98c9dd1d0b629e88a8 (diff) | |
download | sqlalchemy-39fd3442e306f9c2981c347ab2487921f3948a61.tar.gz |
- initial working version of sybase, with modifications to the transactional
model to accomodate Sybase's default mode of "no ddl in transactions".
- identity insert not working yet. it seems the default here might be the
opposite of that of MSSQL.
- reflection will be a full rewrite
- default DBAPI is python-sybase, well documented and nicely DBAPI compliant
except for the bind parameter situation, where we have a straightforward workaround
- full Sybase docs at: http://infocenter.sybase.com/help/index.jsp?topic=/com.sybase.help.ase_15.0/title.htm
Diffstat (limited to 'lib/sqlalchemy/dialects/sybase/base.py')
-rw-r--r-- | lib/sqlalchemy/dialects/sybase/base.py | 354 |
1 files changed, 133 insertions, 221 deletions
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index 886a773d8..2e76a195c 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -5,39 +5,25 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Support for the Sybase iAnywhere database. +"""Support for Sybase Adaptive Server Enterprise (ASE). -This is not (yet) a full backend for Sybase ASE. +Note that this dialect is no longer specific to Sybase iAnywhere. +ASE is the primary support platform. -This dialect is *not* ported to SQLAlchemy 0.6. - -This dialect is *not* tested on SQLAlchemy 0.6. - - -Known issues / TODO: - - * Uses the mx.ODBC driver from egenix (version 2.1.0) - * The current version of sqlalchemy.databases.sybase only supports - mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need - some development) - * Support for pyodbc has been built in but is not yet complete (needs - further development) - * Results of running tests/alltests.py: - Ran 934 tests in 287.032s - FAILED (failures=3, errors=1) - * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751) """ -import datetime, operator - -from sqlalchemy import util, sql, schema, exc +import operator from sqlalchemy.sql import compiler, expression -from sqlalchemy.engine import default, base +from sqlalchemy.engine import default, base, reflection from sqlalchemy import types as sqltypes from sqlalchemy.sql import operators as sql_operators -from sqlalchemy import MetaData, Table, Column -from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey -from sqlalchemy.dialects.sybase.schema import * +from sqlalchemy import schema as sa_schema +from sqlalchemy import util, sql, exc + +from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ + TEXT,DATE,DATETIME, FLOAT, NUMERIC,\ + BIGINT,INT, INTEGER, SMALLINT, BINARY,\ + VARBINARY RESERVED_WORDS = set([ "add", "all", "alter", "and", @@ -99,23 +85,33 @@ RESERVED_WORDS = set([ ]) -class SybaseImage(sqltypes.LargeBinary): - __visit_name__ = 'IMAGE' +class UNICHAR(sqltypes.Unicode): + __visit_name__ = 'UNICHAR' + +class UNIVARCHAR(sqltypes.Unicode): + __visit_name__ = 'UNIVARCHAR' + +class UNITEXT(sqltypes.UnicodeText): + __visit_name__ = 'UNITEXT' -class SybaseBit(sqltypes.TypeEngine): +class TINYINT(sqltypes.Integer): + __visit_name__ = 'TINYINT' + +class BIT(sqltypes.TypeEngine): __visit_name__ = 'BIT' -class SybaseMoney(sqltypes.TypeEngine): +class MONEY(sqltypes.TypeEngine): __visit_name__ = "MONEY" -class SybaseSmallMoney(SybaseMoney): +class SMALLMONEY(sqltypes.TypeEngine): __visit_name__ = "SMALLMONEY" -class SybaseUniqueIdentifier(sqltypes.TypeEngine): +class UNIQUEIDENTIFIER(sqltypes.TypeEngine): __visit_name__ = "UNIQUEIDENTIFIER" - -class SybaseBoolean(sqltypes.Boolean): - pass + +class IMAGE(sqltypes.LargeBinary): + __visit_name__ = 'IMAGE' + class SybaseTypeCompiler(compiler.GenericTypeCompiler): def visit_large_binary(self, type_): @@ -123,6 +119,15 @@ class SybaseTypeCompiler(compiler.GenericTypeCompiler): def visit_boolean(self, type_): return self.visit_BIT(type_) + + def visit_UNICHAR(self, type_): + return "UNICHAR(%d)" % type_.length + + def visit_UNITEXT(self, type_): + return "UNITEXT" + + def visit_TINYINT(self, type_): + return "TINYINT" def visit_IMAGE(self, type_): return "IMAGE" @@ -140,56 +145,41 @@ class SybaseTypeCompiler(compiler.GenericTypeCompiler): return "UNIQUEIDENTIFIER" colspecs = { - sqltypes.LargeBinary : SybaseImage, - sqltypes.Boolean : SybaseBoolean, } ischema_names = { - 'integer' : sqltypes.INTEGER, - 'unsigned int' : sqltypes.Integer, - 'unsigned smallint' : sqltypes.SmallInteger, - 'unsigned bigint' : sqltypes.BigInteger, - 'bigint': sqltypes.BIGINT, - 'smallint' : sqltypes.SMALLINT, - 'tinyint' : sqltypes.SmallInteger, - 'varchar' : sqltypes.VARCHAR, - 'long varchar' : sqltypes.Text, - 'char' : sqltypes.CHAR, - 'decimal' : sqltypes.DECIMAL, - 'numeric' : sqltypes.NUMERIC, - 'float' : sqltypes.FLOAT, - 'double' : sqltypes.Numeric, - 'binary' : sqltypes.LargeBinary, - 'long binary' : sqltypes.LargeBinary, - 'varbinary' : sqltypes.LargeBinary, - 'bit': SybaseBit, - 'image' : SybaseImage, - 'timestamp': sqltypes.TIMESTAMP, - 'money': SybaseMoney, - 'smallmoney': SybaseSmallMoney, - 'uniqueidentifier': SybaseUniqueIdentifier, + 'integer' : INTEGER, + 'unsigned int' : INTEGER, # TODO: unsigned flags + 'unsigned smallint' : SMALLINT, # TODO: unsigned flags + 'unsigned bigint' : BIGINT, # TODO: unsigned flags + 'bigint': BIGINT, + 'smallint' : SMALLINT, + 'tinyint' : TINYINT, + 'varchar' : VARCHAR, + 'long varchar' : TEXT, # TODO + 'char' : CHAR, + 'decimal' : DECIMAL, + 'numeric' : NUMERIC, + 'float' : FLOAT, + 'double' : NUMERIC, # TODO + 'binary' : BINARY, + 'varbinary' : VARBINARY, + 'bit': BIT, + 'image' : IMAGE, + 'timestamp': TIMESTAMP, + 'money': MONEY, + 'smallmoney': MONEY, + 'uniqueidentifier': UNIQUEIDENTIFIER, } class SybaseExecutionContext(default.DefaultExecutionContext): - def post_exec(self): - if self.compiled.isinsert: - table = self.compiled.statement.table - # get the inserted values of the primary key - - # get any sequence IDs first (using @@identity) + if self.isinsert and not self.executemany: self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - lastrowid = int(row[0]) - if lastrowid > 0: - # an IDENTITY was inserted, fetch it - # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?! - if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None: - self._last_inserted_ids = [lastrowid] - else: - self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:] + row = self.cursor.fetchall()[0] + self._lastrowid = int(row[0]) class SybaseSQLCompiler(compiler.SQLCompiler): @@ -204,12 +194,6 @@ class SybaseSQLCompiler(compiler.SQLCompiler): def visit_mod(self, binary, **kw): return "MOD(%s, %s)" % (self.process(binary.left), self.process(binary.right)) - def bindparam_string(self, name): - res = super(SybaseSQLCompiler, self).bindparam_string(name) - if name.lower().startswith('literal'): - res = 'STRING(%s)' % res - return res - def get_select_precolumns(self, select): s = select._distinct and "DISTINCT " or "" if select._limit: @@ -230,32 +214,22 @@ class SybaseSQLCompiler(compiler.SQLCompiler): # Limit in sybase is after the select keyword return "" - def visit_binary(self, binary): + def dont_visit_binary(self, binary): """Move bind parameters to the right-hand side of an operator, where possible.""" if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq: return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator)) else: return super(SybaseSQLCompiler, self).visit_binary(binary) - def label_select_column(self, select, column, asfrom): + def dont_label_select_column(self, select, column, asfrom): if isinstance(column, expression.Function): return column.label(None) else: return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom) - function_rewrites = {'current_date': 'getdate', - } - def visit_function(self, func): - func.name = self.function_rewrites.get(func.name, func.name) - res = super(SybaseSQLCompiler, self).visit_function(func) - if func.name.lower() == 'getdate': - # apply CAST operator - # FIXME: what about _pyodbc ? - cast = expression._Cast(func, SybaseDate_mxodbc) - # infinite recursion - # res = self.visit_cast(cast) - res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause)) - return res +# def visit_getdate_func(self, fn, **kw): + # TODO: need to cast? something ? +# pass def visit_extract(self, extract): field = self.extract_map.get(extract.field, extract.field) @@ -277,27 +251,38 @@ class SybaseSQLCompiler(compiler.SQLCompiler): class SybaseDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + \ + self.dialect.type_compiler.process(column.type) - colspec = self.preparer.format_column(column) + if column.table is None: + raise exc.InvalidRequestError("The Sybase dialect requires Table-bound "\ + "columns in order to generate DDL") + seq_col = column.table._autoincrement_column - if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ - column.autoincrement and isinstance(column.type, sqltypes.Integer): - if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional): - column.sequence = schema.Sequence(column.name + '_seq') + - if hasattr(column, 'sequence'): - column.table.has_sequence = column - #colspec += " numeric(30,0) IDENTITY" - colspec += " Integer IDENTITY" + # install a IDENTITY Sequence if we have an implicit IDENTITY column + if seq_col is column: + sequence = isinstance(column.default, sa_schema.Sequence) and column.default + if sequence: + start, increment = sequence.start or 1, sequence.increment or 1 + else: + start, increment = 1, 1 + if (start, increment) == (1, 1): + colspec += " IDENTITY" + else: + # TODO: need correct syntax for this + colspec += " IDENTITY(%s,%s)" % (start, increment) else: - colspec += " " + self.dialect.type_compiler.process(column.type) - - if not column.nullable: - colspec += " NOT NULL" + if column.nullable is not None: + if not column.nullable or column.primary_key: + colspec += " NOT NULL" + else: + colspec += " NULL" - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default return colspec @@ -324,120 +309,47 @@ class SybaseDialect(default.DefaultDialect): ddl_compiler = SybaseDDLCompiler preparer = SybaseIdentifierPreparer - ported_sqla_06 = False - - schema_name = "dba" - - def __init__(self, **params): - super(SybaseDialect, self).__init__(**params) - self.text_as_varchar = False - - def last_inserted_ids(self): - return self.context.last_inserted_ids - def _get_default_schema_name(self, connection): - # TODO - return self.schema_name + return connection.scalar( + text("SELECT user_name() as user_name", typemap={'user_name':Unicode}) + ) + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + return self.table_names(connection, schema) def table_names(self, connection, schema): - """Ignore the schema and the charset for now.""" - s = sql.select([tables.c.table_name], - sql.not_(tables.c.table_name.like("SYS%")) and - tables.c.creator >= 100 - ) - rp = connection.execute(s) - return [row[0] for row in rp.fetchall()] + + result = connection.execute( + text("select sysobjects.name from sysobjects, sysusers " + "where sysobjects.uid=sysusers.uid and " + "sysusers.name=:schemaname and " + "sysobjects.type='U'", + bindparams=[ + bindparam('schemaname', schema) + ]) + ) + return [r[0] for r in result] def has_table(self, connection, tablename, schema=None): - # FIXME: ignore schemas for sybase - s = sql.select([tables.c.table_name], tables.c.table_name == tablename) - return connection.execute(s).first() is not None + if schema is None: + schema = self.default_schema_name + + result = connection.execute( + text("select sysobjects.name from sysobjects, sysusers " + "where sysobjects.uid=sysusers.uid and " + "sysobjects.name=:tablename and " + "sysusers.name=:schemaname and " + "sysobjects.type='U'", + bindparams=[ + bindparam('tablename', tablename), + bindparam('schemaname', schema) + ]) + ) + return result.scalar() is not None def reflecttable(self, connection, table, include_columns): - # Get base columns - if table.schema is not None: - current_schema = table.schema - else: - current_schema = self.default_schema_name - - s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id]) - - c = connection.execute(s) - found_table = False - # makes sure we append the columns in the correct order - while True: - row = c.fetchone() - if row is None: - break - found_table = True - (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = ( - row[columns.c.column_name], - row[domains.c.domain_name], - row[columns.c.nulls] == 'Y', - row[columns.c.width], - row[domains.c.precision], - row[columns.c.scale], - row[columns.c.default], - row[columns.c.pkey] == 'Y', - row[columns.c.max_identity], - row[tables.c.table_id], - row[columns.c.column_id], - ) - if include_columns and name not in include_columns: - continue - - # FIXME: else problems with SybaseBinary(size) - if numericscale == 0: - numericscale = None - - args = [] - for a in (charlen, numericprec, numericscale): - if a is not None: - args.append(a) - coltype = self.ischema_names.get(type, None) - if coltype == SybaseString and charlen == -1: - coltype = SybaseText() - else: - if coltype is None: - util.warn("Did not recognize type '%s' of column '%s'" % - (type, name)) - coltype = sqltypes.NULLTYPE - coltype = coltype(*args) - colargs = [] - if default is not None: - colargs.append(schema.DefaultClause(sql.text(default))) - - # any sequences ? - col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs) - if int(max_identity) > 0: - col.sequence = schema.Sequence(name + '_identity') - col.sequence.start = int(max_identity) - col.sequence.increment = 1 - - # append the column - table.append_column(col) - - # any foreign key constraint for this table ? - # note: no multi-column foreign keys are considered - s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name } - c = connection.execute(s) - foreignKeys = {} - while True: - row = c.fetchone() - if row is None: - break - (foreign_table, foreign_column, primary_table, primary_column) = ( - row[0], row[1], row[2], row[3], - ) - if not primary_table in foreignKeys.keys(): - foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]] - else: - foreignKeys[primary_table][0].append('%s'%(foreign_column)) - foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column)) - for primary_table in foreignKeys.iterkeys(): - #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)])) - table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True)) - - if not found_table: - raise exc.NoSuchTableError(table.name) + raise NotImplementedError() |