summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/sybase/base.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2010-03-14 22:04:20 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2010-03-14 22:04:20 +0000
commit39fd3442e306f9c2981c347ab2487921f3948a61 (patch)
tree50868207def3fda8434be61660fae8944dde1229 /lib/sqlalchemy/dialects/sybase/base.py
parentd9af1828fbd79cc925abce98c9dd1d0b629e88a8 (diff)
downloadsqlalchemy-39fd3442e306f9c2981c347ab2487921f3948a61.tar.gz
- initial working version of sybase, with modifications to the transactional
model to accomodate Sybase's default mode of "no ddl in transactions". - identity insert not working yet. it seems the default here might be the opposite of that of MSSQL. - reflection will be a full rewrite - default DBAPI is python-sybase, well documented and nicely DBAPI compliant except for the bind parameter situation, where we have a straightforward workaround - full Sybase docs at: http://infocenter.sybase.com/help/index.jsp?topic=/com.sybase.help.ase_15.0/title.htm
Diffstat (limited to 'lib/sqlalchemy/dialects/sybase/base.py')
-rw-r--r--lib/sqlalchemy/dialects/sybase/base.py354
1 files changed, 133 insertions, 221 deletions
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py
index 886a773d8..2e76a195c 100644
--- a/lib/sqlalchemy/dialects/sybase/base.py
+++ b/lib/sqlalchemy/dialects/sybase/base.py
@@ -5,39 +5,25 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Support for the Sybase iAnywhere database.
+"""Support for Sybase Adaptive Server Enterprise (ASE).
-This is not (yet) a full backend for Sybase ASE.
+Note that this dialect is no longer specific to Sybase iAnywhere.
+ASE is the primary support platform.
-This dialect is *not* ported to SQLAlchemy 0.6.
-
-This dialect is *not* tested on SQLAlchemy 0.6.
-
-
-Known issues / TODO:
-
- * Uses the mx.ODBC driver from egenix (version 2.1.0)
- * The current version of sqlalchemy.databases.sybase only supports
- mx.ODBC.Windows (other platforms such as mx.ODBC.unixODBC still need
- some development)
- * Support for pyodbc has been built in but is not yet complete (needs
- further development)
- * Results of running tests/alltests.py:
- Ran 934 tests in 287.032s
- FAILED (failures=3, errors=1)
- * Tested on 'Adaptive Server Anywhere 9' (version 9.0.1.1751)
"""
-import datetime, operator
-
-from sqlalchemy import util, sql, schema, exc
+import operator
from sqlalchemy.sql import compiler, expression
-from sqlalchemy.engine import default, base
+from sqlalchemy.engine import default, base, reflection
from sqlalchemy import types as sqltypes
from sqlalchemy.sql import operators as sql_operators
-from sqlalchemy import MetaData, Table, Column
-from sqlalchemy import String, Integer, SMALLINT, CHAR, ForeignKey
-from sqlalchemy.dialects.sybase.schema import *
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import util, sql, exc
+
+from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\
+ TEXT,DATE,DATETIME, FLOAT, NUMERIC,\
+ BIGINT,INT, INTEGER, SMALLINT, BINARY,\
+ VARBINARY
RESERVED_WORDS = set([
"add", "all", "alter", "and",
@@ -99,23 +85,33 @@ RESERVED_WORDS = set([
])
-class SybaseImage(sqltypes.LargeBinary):
- __visit_name__ = 'IMAGE'
+class UNICHAR(sqltypes.Unicode):
+ __visit_name__ = 'UNICHAR'
+
+class UNIVARCHAR(sqltypes.Unicode):
+ __visit_name__ = 'UNIVARCHAR'
+
+class UNITEXT(sqltypes.UnicodeText):
+ __visit_name__ = 'UNITEXT'
-class SybaseBit(sqltypes.TypeEngine):
+class TINYINT(sqltypes.Integer):
+ __visit_name__ = 'TINYINT'
+
+class BIT(sqltypes.TypeEngine):
__visit_name__ = 'BIT'
-class SybaseMoney(sqltypes.TypeEngine):
+class MONEY(sqltypes.TypeEngine):
__visit_name__ = "MONEY"
-class SybaseSmallMoney(SybaseMoney):
+class SMALLMONEY(sqltypes.TypeEngine):
__visit_name__ = "SMALLMONEY"
-class SybaseUniqueIdentifier(sqltypes.TypeEngine):
+class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
__visit_name__ = "UNIQUEIDENTIFIER"
-
-class SybaseBoolean(sqltypes.Boolean):
- pass
+
+class IMAGE(sqltypes.LargeBinary):
+ __visit_name__ = 'IMAGE'
+
class SybaseTypeCompiler(compiler.GenericTypeCompiler):
def visit_large_binary(self, type_):
@@ -123,6 +119,15 @@ class SybaseTypeCompiler(compiler.GenericTypeCompiler):
def visit_boolean(self, type_):
return self.visit_BIT(type_)
+
+ def visit_UNICHAR(self, type_):
+ return "UNICHAR(%d)" % type_.length
+
+ def visit_UNITEXT(self, type_):
+ return "UNITEXT"
+
+ def visit_TINYINT(self, type_):
+ return "TINYINT"
def visit_IMAGE(self, type_):
return "IMAGE"
@@ -140,56 +145,41 @@ class SybaseTypeCompiler(compiler.GenericTypeCompiler):
return "UNIQUEIDENTIFIER"
colspecs = {
- sqltypes.LargeBinary : SybaseImage,
- sqltypes.Boolean : SybaseBoolean,
}
ischema_names = {
- 'integer' : sqltypes.INTEGER,
- 'unsigned int' : sqltypes.Integer,
- 'unsigned smallint' : sqltypes.SmallInteger,
- 'unsigned bigint' : sqltypes.BigInteger,
- 'bigint': sqltypes.BIGINT,
- 'smallint' : sqltypes.SMALLINT,
- 'tinyint' : sqltypes.SmallInteger,
- 'varchar' : sqltypes.VARCHAR,
- 'long varchar' : sqltypes.Text,
- 'char' : sqltypes.CHAR,
- 'decimal' : sqltypes.DECIMAL,
- 'numeric' : sqltypes.NUMERIC,
- 'float' : sqltypes.FLOAT,
- 'double' : sqltypes.Numeric,
- 'binary' : sqltypes.LargeBinary,
- 'long binary' : sqltypes.LargeBinary,
- 'varbinary' : sqltypes.LargeBinary,
- 'bit': SybaseBit,
- 'image' : SybaseImage,
- 'timestamp': sqltypes.TIMESTAMP,
- 'money': SybaseMoney,
- 'smallmoney': SybaseSmallMoney,
- 'uniqueidentifier': SybaseUniqueIdentifier,
+ 'integer' : INTEGER,
+ 'unsigned int' : INTEGER, # TODO: unsigned flags
+ 'unsigned smallint' : SMALLINT, # TODO: unsigned flags
+ 'unsigned bigint' : BIGINT, # TODO: unsigned flags
+ 'bigint': BIGINT,
+ 'smallint' : SMALLINT,
+ 'tinyint' : TINYINT,
+ 'varchar' : VARCHAR,
+ 'long varchar' : TEXT, # TODO
+ 'char' : CHAR,
+ 'decimal' : DECIMAL,
+ 'numeric' : NUMERIC,
+ 'float' : FLOAT,
+ 'double' : NUMERIC, # TODO
+ 'binary' : BINARY,
+ 'varbinary' : VARBINARY,
+ 'bit': BIT,
+ 'image' : IMAGE,
+ 'timestamp': TIMESTAMP,
+ 'money': MONEY,
+ 'smallmoney': MONEY,
+ 'uniqueidentifier': UNIQUEIDENTIFIER,
}
class SybaseExecutionContext(default.DefaultExecutionContext):
-
def post_exec(self):
- if self.compiled.isinsert:
- table = self.compiled.statement.table
- # get the inserted values of the primary key
-
- # get any sequence IDs first (using @@identity)
+ if self.isinsert and not self.executemany:
self.cursor.execute("SELECT @@identity AS lastrowid")
- row = self.cursor.fetchone()
- lastrowid = int(row[0])
- if lastrowid > 0:
- # an IDENTITY was inserted, fetch it
- # FIXME: always insert in front ? This only works if the IDENTITY is the first column, no ?!
- if not hasattr(self, '_last_inserted_ids') or self._last_inserted_ids is None:
- self._last_inserted_ids = [lastrowid]
- else:
- self._last_inserted_ids = [lastrowid] + self._last_inserted_ids[1:]
+ row = self.cursor.fetchall()[0]
+ self._lastrowid = int(row[0])
class SybaseSQLCompiler(compiler.SQLCompiler):
@@ -204,12 +194,6 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
def visit_mod(self, binary, **kw):
return "MOD(%s, %s)" % (self.process(binary.left), self.process(binary.right))
- def bindparam_string(self, name):
- res = super(SybaseSQLCompiler, self).bindparam_string(name)
- if name.lower().startswith('literal'):
- res = 'STRING(%s)' % res
- return res
-
def get_select_precolumns(self, select):
s = select._distinct and "DISTINCT " or ""
if select._limit:
@@ -230,32 +214,22 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
# Limit in sybase is after the select keyword
return ""
- def visit_binary(self, binary):
+ def dont_visit_binary(self, binary):
"""Move bind parameters to the right-hand side of an operator, where possible."""
if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq:
return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator))
else:
return super(SybaseSQLCompiler, self).visit_binary(binary)
- def label_select_column(self, select, column, asfrom):
+ def dont_label_select_column(self, select, column, asfrom):
if isinstance(column, expression.Function):
return column.label(None)
else:
return super(SybaseSQLCompiler, self).label_select_column(select, column, asfrom)
- function_rewrites = {'current_date': 'getdate',
- }
- def visit_function(self, func):
- func.name = self.function_rewrites.get(func.name, func.name)
- res = super(SybaseSQLCompiler, self).visit_function(func)
- if func.name.lower() == 'getdate':
- # apply CAST operator
- # FIXME: what about _pyodbc ?
- cast = expression._Cast(func, SybaseDate_mxodbc)
- # infinite recursion
- # res = self.visit_cast(cast)
- res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause))
- return res
+# def visit_getdate_func(self, fn, **kw):
+ # TODO: need to cast? something ?
+# pass
def visit_extract(self, extract):
field = self.extract_map.get(extract.field, extract.field)
@@ -277,27 +251,38 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
class SybaseDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column) + " " + \
+ self.dialect.type_compiler.process(column.type)
- colspec = self.preparer.format_column(column)
+ if column.table is None:
+ raise exc.InvalidRequestError("The Sybase dialect requires Table-bound "\
+ "columns in order to generate DDL")
+ seq_col = column.table._autoincrement_column
- if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
- column.autoincrement and isinstance(column.type, sqltypes.Integer):
- if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
- column.sequence = schema.Sequence(column.name + '_seq')
+
- if hasattr(column, 'sequence'):
- column.table.has_sequence = column
- #colspec += " numeric(30,0) IDENTITY"
- colspec += " Integer IDENTITY"
+ # install a IDENTITY Sequence if we have an implicit IDENTITY column
+ if seq_col is column:
+ sequence = isinstance(column.default, sa_schema.Sequence) and column.default
+ if sequence:
+ start, increment = sequence.start or 1, sequence.increment or 1
+ else:
+ start, increment = 1, 1
+ if (start, increment) == (1, 1):
+ colspec += " IDENTITY"
+ else:
+ # TODO: need correct syntax for this
+ colspec += " IDENTITY(%s,%s)" % (start, increment)
else:
- colspec += " " + self.dialect.type_compiler.process(column.type)
-
- if not column.nullable:
- colspec += " NOT NULL"
+ if column.nullable is not None:
+ if not column.nullable or column.primary_key:
+ colspec += " NOT NULL"
+ else:
+ colspec += " NULL"
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
return colspec
@@ -324,120 +309,47 @@ class SybaseDialect(default.DefaultDialect):
ddl_compiler = SybaseDDLCompiler
preparer = SybaseIdentifierPreparer
- ported_sqla_06 = False
-
- schema_name = "dba"
-
- def __init__(self, **params):
- super(SybaseDialect, self).__init__(**params)
- self.text_as_varchar = False
-
- def last_inserted_ids(self):
- return self.context.last_inserted_ids
-
def _get_default_schema_name(self, connection):
- # TODO
- return self.schema_name
+ return connection.scalar(
+ text("SELECT user_name() as user_name", typemap={'user_name':Unicode})
+ )
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+ return self.table_names(connection, schema)
def table_names(self, connection, schema):
- """Ignore the schema and the charset for now."""
- s = sql.select([tables.c.table_name],
- sql.not_(tables.c.table_name.like("SYS%")) and
- tables.c.creator >= 100
- )
- rp = connection.execute(s)
- return [row[0] for row in rp.fetchall()]
+
+ result = connection.execute(
+ text("select sysobjects.name from sysobjects, sysusers "
+ "where sysobjects.uid=sysusers.uid and "
+ "sysusers.name=:schemaname and "
+ "sysobjects.type='U'",
+ bindparams=[
+ bindparam('schemaname', schema)
+ ])
+ )
+ return [r[0] for r in result]
def has_table(self, connection, tablename, schema=None):
- # FIXME: ignore schemas for sybase
- s = sql.select([tables.c.table_name], tables.c.table_name == tablename)
- return connection.execute(s).first() is not None
+ if schema is None:
+ schema = self.default_schema_name
+
+ result = connection.execute(
+ text("select sysobjects.name from sysobjects, sysusers "
+ "where sysobjects.uid=sysusers.uid and "
+ "sysobjects.name=:tablename and "
+ "sysusers.name=:schemaname and "
+ "sysobjects.type='U'",
+ bindparams=[
+ bindparam('tablename', tablename),
+ bindparam('schemaname', schema)
+ ])
+ )
+ return result.scalar() is not None
def reflecttable(self, connection, table, include_columns):
- # Get base columns
- if table.schema is not None:
- current_schema = table.schema
- else:
- current_schema = self.default_schema_name
-
- s = sql.select([columns, domains], tables.c.table_name==table.name, from_obj=[columns.join(tables).join(domains)], order_by=[columns.c.column_id])
-
- c = connection.execute(s)
- found_table = False
- # makes sure we append the columns in the correct order
- while True:
- row = c.fetchone()
- if row is None:
- break
- found_table = True
- (name, type, nullable, charlen, numericprec, numericscale, default, primary_key, max_identity, table_id, column_id) = (
- row[columns.c.column_name],
- row[domains.c.domain_name],
- row[columns.c.nulls] == 'Y',
- row[columns.c.width],
- row[domains.c.precision],
- row[columns.c.scale],
- row[columns.c.default],
- row[columns.c.pkey] == 'Y',
- row[columns.c.max_identity],
- row[tables.c.table_id],
- row[columns.c.column_id],
- )
- if include_columns and name not in include_columns:
- continue
-
- # FIXME: else problems with SybaseBinary(size)
- if numericscale == 0:
- numericscale = None
-
- args = []
- for a in (charlen, numericprec, numericscale):
- if a is not None:
- args.append(a)
- coltype = self.ischema_names.get(type, None)
- if coltype == SybaseString and charlen == -1:
- coltype = SybaseText()
- else:
- if coltype is None:
- util.warn("Did not recognize type '%s' of column '%s'" %
- (type, name))
- coltype = sqltypes.NULLTYPE
- coltype = coltype(*args)
- colargs = []
- if default is not None:
- colargs.append(schema.DefaultClause(sql.text(default)))
-
- # any sequences ?
- col = schema.Column(name, coltype, nullable=nullable, primary_key=primary_key, *colargs)
- if int(max_identity) > 0:
- col.sequence = schema.Sequence(name + '_identity')
- col.sequence.start = int(max_identity)
- col.sequence.increment = 1
-
- # append the column
- table.append_column(col)
-
- # any foreign key constraint for this table ?
- # note: no multi-column foreign keys are considered
- s = "select st1.table_name, sc1.column_name, st2.table_name, sc2.column_name from systable as st1 join sysfkcol on st1.table_id=sysfkcol.foreign_table_id join sysforeignkey join systable as st2 on sysforeignkey.primary_table_id = st2.table_id join syscolumn as sc1 on sysfkcol.foreign_column_id=sc1.column_id and sc1.table_id=st1.table_id join syscolumn as sc2 on sysfkcol.primary_column_id=sc2.column_id and sc2.table_id=st2.table_id where st1.table_name='%(table_name)s';" % { 'table_name' : table.name }
- c = connection.execute(s)
- foreignKeys = {}
- while True:
- row = c.fetchone()
- if row is None:
- break
- (foreign_table, foreign_column, primary_table, primary_column) = (
- row[0], row[1], row[2], row[3],
- )
- if not primary_table in foreignKeys.keys():
- foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]]
- else:
- foreignKeys[primary_table][0].append('%s'%(foreign_column))
- foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column))
- for primary_table in foreignKeys.iterkeys():
- #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)]))
- table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1], link_to_name=True))
-
- if not found_table:
- raise exc.NoSuchTableError(table.name)
+ raise NotImplementedError()