diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-02-25 22:44:52 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-02-25 22:44:52 +0000 |
| commit | 962c22c9eda7d2ab7dc0b41bd1c7a52cf0c9d008 (patch) | |
| tree | f0ab113c7947c80dfea42d4a1bef52217bf6ed96 /lib/sqlalchemy/databases | |
| parent | 8fa3becd5fac57bb898a0090bafaac377b60f070 (diff) | |
| download | sqlalchemy-962c22c9eda7d2ab7dc0b41bd1c7a52cf0c9d008.tar.gz | |
migrated (most) docstrings to pep-257 format, docstring generator using straight <pre> + trim() func
for now. applies most of [ticket:214], compliemnts of Lele Gaifax
Diffstat (limited to 'lib/sqlalchemy/databases')
| -rw-r--r-- | lib/sqlalchemy/databases/firebird.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/information_schema.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 98 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 56 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 131 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 115 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 57 |
7 files changed, 306 insertions, 170 deletions
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 5a25b12db..91a0869c6 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -330,17 +330,23 @@ class FBCompiler(ansisql.ANSICompiler): def visit_insert(self, insert): """Inserts are required to have the primary keys be explicitly present. - mapper will by default not put them in the insert statement to comply - with autoincrement fields that require they not be present. So, - put them all in for all primary key columns.""" + + mapper will by default not put them in the insert statement + to comply with autoincrement fields that require they not be + present. So, put them all in for all primary key columns. + """ + for c in insert.table.primary_key: if not self.parameters.has_key(c.key): self.parameters[c.key] = None return ansisql.ANSICompiler.visit_insert(self, insert) def visit_select_precolumns(self, select): - """Called when building a SELECT statement, position is just before column list - Firebird puts the limit and offset right after the select...""" + """Called when building a ``SELECT`` statement, position is just + before column list Firebird puts the limit and offset right + after the ``SELECT``... + """ + result = "" if select.limit: result += " FIRST %d " % select.limit @@ -351,7 +357,7 @@ class FBCompiler(ansisql.ANSICompiler): return result def limit_clause(self, select): - """Already taken care of in the visit_select_precolumns method.""" + """Already taken care of in the `visit_select_precolumns` method.""" return "" diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index 5a7369ccd..54c47b6f4 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -85,6 +85,7 @@ class ISchema(object): def __init__(self, engine): self.engine = engine self.cache = {} + def __getattr__(self, name): if name not in self.cache: # This is a bit of a hack. diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 8cde7179f..254ea6013 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -1,36 +1,45 @@ # mssql.py -""" -notes: - supports the pymssq, adodbapi and pyodbc interfaces +"""MSSQL backend, thru either pymssq, adodbapi or pyodbc interfaces. + +* ``IDENTITY`` columns are supported by using SA ``schema.Sequence()`` + objects. In other words:: + + Table('test', mss_engine, + Column('id', Integer, Sequence('blah',100,10), primary_key=True), + Column('name', String(20)) + ).create() - IDENTITY columns are supported by using SA schema.Sequence() objects. In other words: - Table('test', mss_engine, - Column('id', Integer, Sequence('blah',100,10), primary_key=True), - Column('name', String(20)) - ).create() + would yield:: - would yield: - CREATE TABLE test ( - id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, - name VARCHAR(20) - ) - note that the start & increment values for sequences are optional and will default to 1,1 + CREATE TABLE test ( + id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, + name VARCHAR(20) + ) - support for SET IDENTITY_INSERT ON mode (automagic on / off for INSERTs) + Note that the start & increment values for sequences are optional + and will default to 1,1. - support for auto-fetching of @@IDENTITY on insert +* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for + ``INSERT``s) - select.limit implemented as SELECT TOP n +* Support for auto-fetching of ``@@IDENTITY`` on ``INSERT`` + +* ``select.limit`` implemented as ``SELECT TOP n`` Known issues / TODO: - no support for more than one IDENTITY column per table - no support for table reflection of IDENTITY columns with (seed,increment) values other than (1,1) - no support for GUID type columns (yet) - pymssql has problems with binary and unicode data that this module does NOT work around - adodbapi fails testtypes.py unit test on unicode data too -- issue with the test? +* No support for more than one ``IDENTITY`` column per table no + +* No support for table reflection of ``IDENTITY`` columns with + (seed,increment) values other than (1,1) + +* No support for ``GUID`` type columns (yet) + +* pymssql has problems with binary and unicode data that this module + does **not** work around adodbapi fails testtypes.py unit test on + unicode data too -- issue with the test? """ import sys, StringIO, string, types, re, datetime @@ -138,6 +147,7 @@ class MSNumeric(sqltypes.Numeric): class MSFloat(sqltypes.Float): def get_col_spec(self): return "FLOAT(%(precision)s)" % {'precision': self.precision} + def convert_bind_param(self, value, dialect): """By converting to string, we can use Decimal types round-trip.""" return str(value) @@ -197,14 +207,17 @@ class MSDate(sqltypes.Date): class MSText(sqltypes.TEXT): def get_col_spec(self): return "TEXT" + class MSString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} - class MSNVarchar(MSString): - """NVARCHAR string, does unicode conversion if dialect.convert_encoding is true""" + """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True. + """ + impl = sqltypes.Unicode + def get_col_spec(self): if self.length: return "NVARCHAR(%(length)s)" % {'length' : self.length} @@ -214,36 +227,45 @@ class MSNVarchar(MSString): class AdoMSNVarchar(MSNVarchar): def convert_bind_param(self, value, dialect): return value + def convert_result_value(self, value, dialect): return value class MSUnicode(sqltypes.Unicode): - """Unicode subclass, does unicode conversion in all cases, uses NVARCHAR impl""" + """Unicode subclass, does Unicode conversion in all cases, uses NVARCHAR impl.""" + impl = MSNVarchar class AdoMSUnicode(MSUnicode): impl = AdoMSNVarchar + def convert_bind_param(self, value, dialect): return value + def convert_result_value(self, value, dialect): return value class MSChar(sqltypes.CHAR): def get_col_spec(self): return "CHAR(%(length)s)" % {'length' : self.length} + class MSNChar(sqltypes.NCHAR): def get_col_spec(self): return "NCHAR(%(length)s)" % {'length' : self.length} + class MSBinary(sqltypes.Binary): def get_col_spec(self): return "IMAGE" + class MSBoolean(sqltypes.Boolean): def get_col_spec(self): return "BIT" + def convert_result_value(self, value, dialect): if value is None: return None return value and True or False + def convert_bind_param(self, value, dialect): if value is True: return 1 @@ -307,8 +329,12 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): super(MSSQLExecutionContext, self).__init__(dialect) def pre_exec(self, engine, proxy, compiled, parameters, **kwargs): - """ MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns. - Activate it if the feature is turned on and needed. """ + """MS-SQL has a special mode for inserting non-NULL values + into IDENTITY columns. + + Activate it if the feature is turned on and needed. + """ + if getattr(compiled, "isinsert", False): tbl = compiled.statement.table if not hasattr(tbl, 'has_sequence'): @@ -337,7 +363,11 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): super(MSSQLExecutionContext, self).pre_exec(engine, proxy, compiled, parameters, **kwargs) def post_exec(self, engine, proxy, compiled, parameters, **kwargs): - """ Turn off the INDENTITY_INSERT mode if it's been activated, and fetch recently inserted IDENTIFY values (works only for one column) """ + """Turn off the INDENTITY_INSERT mode if it's been activated, + and fetch recently inserted IDENTIFY values (works only for + one column). + """ + if getattr(compiled, "isinsert", False): if self.IINSERT: proxy("SET IDENTITY_INSERT %s OFF" % compiled.statement.table.name) @@ -429,8 +459,7 @@ class MSSQLDialect(ansisql.ANSIDialect): c = self._pool.connect() c.supportsTransactions = 0 return c - - + def dbapi(self): return self.module @@ -535,7 +564,6 @@ class MSSQLDialect(ansisql.ANSIDialect): if 'PRIMARY' in row[TC.c.constraint_type.name]: table.primary_key.add(table.c[row[0]]) - # Foreign key constraints s = sql.select([C.c.column_name, R.c.table_schema, R.c.table_name, R.c.column_name, @@ -562,8 +590,6 @@ class MSSQLDialect(ansisql.ANSIDialect): if fknm and scols: table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) - - class PyMSSQLDialect(MSSQLDialect): def do_rollback(self, connection): @@ -578,7 +604,6 @@ class PyMSSQLDialect(MSSQLDialect): if hasattr(self, 'query_timeout'): dbmodule._mssql.set_query_timeout(self.query_timeout) return r - ## This code is leftover from the initial implementation, for reference ## def do_begin(self, connection): @@ -611,7 +636,6 @@ class PyMSSQLDialect(MSSQLDialect): ## r.query("begin tran") ## r.fetch_array() - class MSSQLCompiler(ansisql.ANSICompiler): def __init__(self, dialect, statement, parameters, **kwargs): super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs) @@ -627,7 +651,6 @@ class MSSQLCompiler(ansisql.ANSICompiler): def limit_clause(self, select): # Limit in mssql is after the select keyword; MSsql has no support for offset return "" - def visit_table(self, table): # alias schema-qualified tables @@ -699,7 +722,6 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec - class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX " + index.table.name + "." + index.name) @@ -711,9 +733,11 @@ class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner): class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def __init__(self, dialect): super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') + def _escape_identifier(self, value): #TODO: determin MSSQL's escapeing rules return value + def _fold_identifier_case(self, value): #TODO: determin MSSQL's case folding rules return value diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index c6bf2695f..1cb41cf76 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -25,21 +25,24 @@ def kw_colspec(self, spec): if self.zerofill: spec += ' ZEROFILL' return spec - + class MSNumeric(sqltypes.Numeric): def __init__(self, precision = 10, length = 2, **kw): self.unsigned = 'unsigned' in kw self.zerofill = 'zerofill' in kw super(MSNumeric, self).__init__(precision, length) + def get_col_spec(self): if self.precision is None: return kw_colspec(self, "NUMERIC") else: return kw_colspec(self, "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) + class MSDecimal(MSNumeric): def get_col_spec(self): if self.precision is not None and self.length is not None: return kw_colspec(self, "DECIMAL(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) + class MSDouble(MSNumeric): def __init__(self, precision=10, length=2, **kw): if (precision is None and length is not None) or (precision is not None and length is None): @@ -47,11 +50,13 @@ class MSDouble(MSNumeric): self.unsigned = 'unsigned' in kw self.zerofill = 'zerofill' in kw super(MSDouble, self).__init__(precision, length) + def get_col_spec(self): if self.precision is not None and self.length is not None: return "DOUBLE(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} else: return kw_colspec(self, "DOUBLE") + class MSFloat(sqltypes.Float): def __init__(self, precision=10, length=None, **kw): if length is not None: @@ -59,6 +64,7 @@ class MSFloat(sqltypes.Float): self.unsigned = 'unsigned' in kw self.zerofill = 'zerofill' in kw super(MSFloat, self).__init__(precision) + def get_col_spec(self): if hasattr(self, 'length') and self.length is not None: return kw_colspec(self, "FLOAT(%(precision)s,%(length)s)" % {'precision': self.precision, 'length' : self.length}) @@ -66,23 +72,27 @@ class MSFloat(sqltypes.Float): return kw_colspec(self, "FLOAT(%(precision)s)" % {'precision': self.precision}) else: return kw_colspec(self, "FLOAT") + class MSInteger(sqltypes.Integer): def __init__(self, length=None, **kw): self.length = length self.unsigned = 'unsigned' in kw self.zerofill = 'zerofill' in kw super(MSInteger, self).__init__() + def get_col_spec(self): if self.length is not None: return kw_colspec(self, "INTEGER(%(length)s)" % {'length': self.length}) else: return kw_colspec(self, "INTEGER") + class MSBigInteger(MSInteger): def get_col_spec(self): if self.length is not None: return kw_colspec(self, "BIGINT(%(length)s)" % {'length': self.length}) else: return kw_colspec(self, "BIGINT") + class MSSmallInteger(sqltypes.Smallinteger): def __init__(self, length=None, **kw): self.length = length @@ -94,54 +104,65 @@ class MSSmallInteger(sqltypes.Smallinteger): return kw_colspec(self, "SMALLINT(%(length)s)" % {'length': self.length}) else: return kw_colspec(self, "SMALLINT") + class MSDateTime(sqltypes.DateTime): def get_col_spec(self): return "DATETIME" + class MSDate(sqltypes.Date): def get_col_spec(self): return "DATE" + class MSTime(sqltypes.Time): def get_col_spec(self): return "TIME" + def convert_result_value(self, value, dialect): # convert from a timedelta value if value is not None: return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) else: return None - + class MSText(sqltypes.TEXT): def __init__(self, **kw): self.binary = 'binary' in kw super(MSText, self).__init__() + def get_col_spec(self): return "TEXT" + class MSTinyText(MSText): def get_col_spec(self): if self.binary: return "TEXT BINARY" else: return "TEXT" + class MSMediumText(MSText): def get_col_spec(self): if self.binary: return "MEDIUMTEXT BINARY" else: return "MEDIUMTEXT" + class MSLongText(MSText): def get_col_spec(self): if self.binary: return "LONGTEXT BINARY" else: return "LONGTEXT" + class MSString(sqltypes.String): def __init__(self, length=None, *extra): sqltypes.String.__init__(self, length=length) def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} + class MSChar(sqltypes.CHAR): def get_col_spec(self): return "CHAR(%(length)s)" % {'length' : self.length} + class MSBinary(sqltypes.Binary): def get_col_spec(self): if self.length is not None and self.length <=255: @@ -149,6 +170,7 @@ class MSBinary(sqltypes.Binary): return "BINARY(%d)" % self.length else: return "BLOB" + def convert_result_value(self, value, dialect): if value is None: return None @@ -158,7 +180,7 @@ class MSBinary(sqltypes.Binary): class MSMediumBlob(MSBinary): def get_col_spec(self): return "MEDIUMBLOB" - + class MSEnum(MSString): def __init__(self, *enums): self.__enums_hidden = enums @@ -172,17 +194,20 @@ class MSEnum(MSString): strip_enums.append(a) self.enums = strip_enums super(MSEnum, self).__init__(length) + def get_col_spec(self): return "ENUM(%s)" % ",".join(self.__enums_hidden) - + class MSBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" + def convert_result_value(self, value, dialect): if value is None: return None return value and True or False + def convert_bind_param(self, value, dialect): if value is True: return 1 @@ -192,7 +217,7 @@ class MSBoolean(sqltypes.Boolean): return None else: return value and True or False - + colspecs = { # sqltypes.BIGinteger : MSInteger, sqltypes.Integer : MSInteger, @@ -215,7 +240,7 @@ ischema_names = { 'int' : MSInteger, 'mediumint' : MSInteger, 'smallint' : MSSmallInteger, - 'tinyint' : MSSmallInteger, + 'tinyint' : MSSmallInteger, 'varchar' : MSString, 'char' : MSChar, 'text' : MSText, @@ -245,7 +270,6 @@ def descriptor(): ('host',"Hostname", None), ]} - class MySQLExecutionContext(default.DefaultExecutionContext): def post_exec(self, engine, proxy, compiled, parameters, **kwargs): if getattr(compiled, "isinsert", False): @@ -318,7 +342,6 @@ class MySQLDialect(ansisql.ANSIDialect): if o.args[0] == 2006 or o.args[0] == 2014: cursor.invalidate() raise o - def do_rollback(self, connection): # MySQL without InnoDB doesnt support rollback() @@ -331,7 +354,7 @@ class MySQLDialect(ansisql.ANSIDialect): if not hasattr(self, '_default_schema_name'): self._default_schema_name = text("select database()", self).scalar() return self._default_schema_name - + def dbapi(self): return self.module @@ -345,7 +368,7 @@ class MySQLDialect(ansisql.ANSIDialect): if isinstance(cs, array): cs = cs.tostring() case_sensitive = int(cs) == 0 - + if not case_sensitive: table.name = table.name.lower() table.metadata.tables[table.name]= table @@ -364,7 +387,7 @@ class MySQLDialect(ansisql.ANSIDialect): # these can come back as unicode if use_unicode=1 in the mysql connection (name, type, nullable, primary_key, default) = (str(row[0]), str(row[1]), row[2] == 'YES', row[3] == 'PRI', row[4]) - + match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type) col_type = match.group(1) args = match.group(2) @@ -391,7 +414,7 @@ class MySQLDialect(ansisql.ANSIDialect): colargs= [] if default: colargs.append(schema.PassiveDefault(sql.text(default))) - table.append_column(schema.Column(name, coltype, *colargs, + table.append_column(schema.Column(name, coltype, *colargs, **dict(primary_key=primary_key, nullable=nullable, ))) @@ -401,7 +424,7 @@ class MySQLDialect(ansisql.ANSIDialect): if not found_table: raise exceptions.NoSuchTableError(table.name) - + def moretableinfo(self, connection, table): """Return (tabletype, {colname:foreignkey,...}) execute(SHOW CREATE TABLE child) => @@ -438,10 +461,8 @@ class MySQLDialect(ansisql.ANSIDialect): table.append_constraint(constraint) return tabletype - class MySQLCompiler(ansisql.ANSICompiler): - def visit_cast(self, cast): """hey ho MySQL supports almost no types at all for CAST""" if (isinstance(cast.type, sqltypes.Date) or isinstance(cast.type, sqltypes.Time) or isinstance(cast.type, sqltypes.DateTime)): @@ -467,7 +488,7 @@ class MySQLCompiler(ansisql.ANSICompiler): text += " \n LIMIT 18446744073709551615" text += " OFFSET " + str(select.offset) return text - + class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): t = column.type.engine_impl(self.engine) @@ -495,6 +516,7 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX " + index.name + " ON " + index.table.name) self.execute() + def drop_foreignkey(self, constraint): self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % (self.preparer.format_table(constraint.table), constraint.name)) self.execute() @@ -502,9 +524,11 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper): class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def __init__(self, dialect): super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote='`') + def _escape_identifier(self, value): #TODO: determin MySQL's escaping rules return value + def _fold_identifier_case(self, value): #TODO: determin MySQL's case folding rules return value diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index d7b78d3dd..d53de0654 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -18,22 +18,25 @@ except: ORACLE_BINARY_TYPES = [getattr(cx_Oracle, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(cx_Oracle, k)] - class OracleNumeric(sqltypes.Numeric): def get_col_spec(self): if self.precision is None: return "NUMERIC" else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} + class OracleInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" + class OracleSmallInteger(sqltypes.Smallinteger): def get_col_spec(self): return "SMALLINT" + class OracleDateTime(sqltypes.DateTime): def get_col_spec(self): return "DATE" + # Note: # Oracle DATE == DATETIME # Oracle does not allow milliseconds in DATE @@ -43,32 +46,40 @@ class OracleDateTime(sqltypes.DateTime): class OracleTimestamp(sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP" + def get_dbapi_type(self, dialect): return dialect.TIMESTAMP - + class OracleText(sqltypes.TEXT): def get_col_spec(self): return "CLOB" + class OracleString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} + class OracleRaw(sqltypes.Binary): def get_col_spec(self): return "RAW(%(length)s)" % {'length' : self.length} + class OracleChar(sqltypes.CHAR): def get_col_spec(self): return "CHAR(%(length)s)" % {'length' : self.length} + class OracleBinary(sqltypes.Binary): def get_dbapi_type(self, dbapi): return dbapi.BINARY + def get_col_spec(self): return "BLOB" + def convert_bind_param(self, value, dialect): if value is None: return None else: # this is RAWTOHEX return ''.join(["%.2X" % ord(c) for c in value]) + def convert_result_value(self, value, dialect): if value is None: return None @@ -78,10 +89,12 @@ class OracleBinary(sqltypes.Binary): class OracleBoolean(sqltypes.Boolean): def get_col_spec(self): return "SMALLINT" + def convert_result_value(self, value, dialect): if value is None: return None return value and True or False + def convert_bind_param(self, value, dialect): if value is True: return 1 @@ -90,9 +103,8 @@ class OracleBoolean(sqltypes.Boolean): elif value is None: return None else: - return value and True or False + return value and True or False - colspecs = { sqltypes.Integer : OracleInteger, sqltypes.Smallinteger : OracleSmallInteger, @@ -121,8 +133,6 @@ ischema_names = { 'DOUBLE PRECISION' : OracleNumeric, } - - def descriptor(): return {'name':'oracle', 'description':'Oracle', @@ -137,7 +147,7 @@ class OracleExecutionContext(default.DefaultExecutionContext): super(OracleExecutionContext, self).pre_exec(engine, proxy, compiled, parameters) if self.dialect.auto_setinputsizes: self.set_input_sizes(proxy(), parameters) - + class OracleDialect(ansisql.ANSIDialect): def __init__(self, use_ansi=True, auto_setinputsizes=False, module=None, threaded=True, **kwargs): self.use_ansi = use_ansi @@ -173,7 +183,7 @@ class OracleDialect(ansisql.ANSIDialect): ) opts.update(url.query) return ([], opts) - + def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) @@ -188,14 +198,16 @@ class OracleDialect(ansisql.ANSIDialect): def compiler(self, statement, bindparams, **kwargs): return OracleCompiler(self, statement, bindparams, **kwargs) + def schemagenerator(self, *args, **kwargs): return OracleSchemaGenerator(*args, **kwargs) + def schemadropper(self, *args, **kwargs): return OracleSchemaDropper(*args, **kwargs) + def defaultrunner(self, engine, proxy): return OracleDefaultRunner(engine, proxy) - def has_table(self, connection, table_name, schema=None): cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()}) return bool( cursor.fetchone() is not None ) @@ -229,8 +241,12 @@ class OracleDialect(ansisql.ANSIDialect): raise exceptions.AssertionError("There are multiple tables with name '%s' visible to the schema, you must specifiy owner" % name) else: return None + def _resolve_table_owner(self, connection, name, table, dblink=''): - """locate the given table in the ALL_TAB_COLUMNS view, including searching for equivalent synonyms and dblinks""" + """Locate the given table in the ``ALL_TAB_COLUMNS`` view, + including searching for equivalent synonyms and dblinks. + """ + c = connection.execute ("select distinct OWNER from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name" % {'dblink':dblink}, {'table_name':name}) rows = c.fetchall() try: @@ -239,10 +255,10 @@ class OracleDialect(ansisql.ANSIDialect): except exceptions.SQLAlchemyError: # locate synonyms c = connection.execute ("""select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK - from ALL_SYNONYMS%(dblink)s + from ALL_SYNONYMS%(dblink)s where SYNONYM_NAME = :synonym_name - and (DB_LINK IS NOT NULL - or ((TABLE_NAME, TABLE_OWNER) in + and (DB_LINK IS NOT NULL + or ((TABLE_NAME, TABLE_OWNER) in (select TABLE_NAME, OWNER from ALL_TAB_COLUMNS%(dblink)s)))""" % {'dblink':dblink}, {'synonym_name':name}) rows = c.fetchall() @@ -262,20 +278,19 @@ class OracleDialect(ansisql.ANSIDialect): return name, owner, dblink raise - def reflecttable(self, connection, table): preparer = self.identifier_preparer if not preparer.should_quote(table): name = table.name.upper() else: name = table.name - + # search for table, including across synonyms and dblinks. # locate the actual name of the table, the real owner, and any dblink clause needed. actual_name, owner, dblink = self._resolve_table_owner(connection, name, table) - + c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':actual_name, 'owner':owner}) - + while True: row = c.fetchone() if row is None: @@ -305,20 +320,20 @@ class OracleDialect(ansisql.ANSIDialect): coltype = ischema_names[coltype] except KeyError: raise exceptions.AssertionError("Cant get coltype for type '%s' on colname '%s'" % (coltype, colname)) - + colargs = [] if default is not None: colargs.append(schema.PassiveDefault(sql.text(default))) - - # if name comes back as all upper, assume its case folded - if (colname.upper() == colname): + + # if name comes back as all upper, assume its case folded + if (colname.upper() == colname): colname = colname.lower() - + table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs)) if not len(table.columns): raise exceptions.AssertionError("Couldn't find any column information for table %s" % actual_name) - + c = connection.execute("""SELECT ac.constraint_name, ac.constraint_type, @@ -339,13 +354,13 @@ class OracleDialect(ansisql.ANSIDialect): -- order multiple primary keys correctly ORDER BY ac.constraint_name, loc.position, rem.position""" % {'dblink':dblink}, {'table_name' : actual_name, 'owner' : owner}) - + fks = {} while True: row = c.fetchone() if row is None: break - #print "ROW:" , row + #print "ROW:" , row (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row if cons_type == 'P': table.primary_key.add(table.c[local_column]) @@ -389,12 +404,17 @@ class OracleDialect(ansisql.ANSIDialect): OracleDialect.logger = logging.class_logger(OracleDialect) class OracleCompiler(ansisql.ANSICompiler): - """oracle compiler modifies the lexical structure of Select statements to work under - non-ANSI configured Oracle databases, if the use_ansi flag is False.""" - + """Oracle compiler modifies the lexical structure of Select + statements to work under non-ANSI configured Oracle databases, if + the use_ansi flag is False. + """ + def default_from(self): - """called when a SELECT statement has no froms, and no FROM clause is to be appended. - gives Oracle a chance to tack on a "FROM DUAL" to the string output. """ + """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. + + The Oracle compiler tacks a "FROM DUAL" to the statement. + """ + return " FROM DUAL" def apply_function_parens(self, func): @@ -403,7 +423,7 @@ class OracleCompiler(ansisql.ANSICompiler): def visit_join(self, join): if self.dialect.use_ansi: return ansisql.ANSICompiler.visit_join(self, join) - + self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right) self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause) self.strings[join] = self.froms[join] @@ -421,42 +441,50 @@ class OracleCompiler(ansisql.ANSICompiler): self.visit_compound(self.wheres[join]) def visit_insert_sequence(self, column, sequence, parameters): - """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures - that the column is present in the generated column list""" + """This is the `sequence` equivalent to ``ANSICompiler``'s + `visit_insert_column_default` which ensures that the column is + present in the generated column list. + """ + parameters.setdefault(column.key, None) - + def visit_alias(self, alias): - """oracle doesnt like 'FROM table AS alias'. is the AS standard SQL??""" + """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??""" + self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name self.strings[alias] = self.get_str(alias.original) - + def visit_column(self, column): ansisql.ANSICompiler.visit_column(self, column) if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable: self.strings[column] = self.strings[column] + "(+)" - + def visit_insert(self, insert): - """inserts are required to have the primary keys be explicitly present. - mapper will by default not put them in the insert statement to comply - with autoincrement fields that require they not be present. so, - put them all in for all primary key columns.""" + """``INSERT``s are required to have the primary keys be explicitly present. + + Mapper will by default not put them in the insert statement + to comply with autoincrement fields that require they not be + present. so, put them all in for all primary key columns. + """ + for c in insert.table.primary_key: if not self.parameters.has_key(c.key): self.parameters[c.key] = None return ansisql.ANSICompiler.visit_insert(self, insert) def _TODO_visit_compound_select(self, select): - """need to determine how to get LIMIT/OFFSET into a UNION for oracle""" + """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" + if getattr(select, '_oracle_visit', False): # cancel out the compiled order_by on the select if hasattr(select, "order_by_clause"): self.strings[select.order_by_clause] = "" ansisql.ANSICompiler.visit_compound_select(self, select) return - + if select.limit is not None or select.offset is not None: select._oracle_visit = True - # to use ROW_NUMBER(), an ORDER BY is required. + # to use ROW_NUMBER(), an ORDER BY is required. orderby = self.strings[select.order_by_clause] if not orderby: orderby = select.oid_column @@ -478,10 +506,12 @@ class OracleCompiler(ansisql.ANSICompiler): self.froms[select] = self.froms[limitselect] else: ansisql.ANSICompiler.visit_compound_select(self, select) - + def visit_select(self, select): - """looks for LIMIT and OFFSET in a select statement, and if so tries to wrap it in a - subquery with row_number() criterion.""" + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``row_number()`` criterion. + """ + # TODO: put a real copy-container on Select and copy, or somehow make this # not modify the Select statement if getattr(select, '_oracle_visit', False): @@ -493,7 +523,7 @@ class OracleCompiler(ansisql.ANSICompiler): if select.limit is not None or select.offset is not None: select._oracle_visit = True - # to use ROW_NUMBER(), an ORDER BY is required. + # to use ROW_NUMBER(), an ORDER BY is required. orderby = self.strings[select.order_by_clause] if not orderby: orderby = select.oid_column @@ -512,7 +542,7 @@ class OracleCompiler(ansisql.ANSICompiler): self.froms[select] = self.froms[limitselect] else: ansisql.ANSICompiler.visit_select(self, select) - + def limit_clause(self, select): return "" @@ -539,7 +569,6 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() - class OracleSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): if self.engine.dialect.has_sequence(self.connection, sequence.name): @@ -550,7 +579,7 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner): def exec_default_sql(self, default): c = sql.select([default.arg], from_obj=["DUAL"], engine=self.engine).compile() return self.proxy(str(c), c.get_params()).fetchone()[0] - + def visit_sequence(self, seq): return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0] diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index b76aafc22..83dac516a 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -15,7 +15,7 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions from sqlalchemy.databases import information_schema as ischema -from sqlalchemy import * +from sqlalchemy import * import re try: @@ -42,24 +42,30 @@ class PGNumeric(sqltypes.Numeric): return "NUMERIC" else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} + class PGFloat(sqltypes.Float): def get_col_spec(self): if not self.precision: return "FLOAT" else: return "FLOAT(%(precision)s)" % {'precision': self.precision} + class PGInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" + class PGSmallInteger(sqltypes.Smallinteger): def get_col_spec(self): return "SMALLINT" + class PGBigInteger(PGInteger): def get_col_spec(self): return "BIGINT" + class PG2DateTime(sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + class PG1DateTime(sqltypes.DateTime): def convert_bind_param(self, value, dialect): if value is not None: @@ -73,6 +79,7 @@ class PG1DateTime(sqltypes.DateTime): return psycopg.TimestampFromMx(value) else: return None + def convert_result_value(self, value, dialect): if value is None: return None @@ -82,11 +89,14 @@ class PG1DateTime(sqltypes.DateTime): return datetime.datetime(value.year, value.month, value.day, value.hour, value.minute, seconds, microseconds) + def get_col_spec(self): return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + class PG2Date(sqltypes.Date): def get_col_spec(self): return "DATE" + class PG1Date(sqltypes.Date): def convert_bind_param(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime @@ -95,14 +105,18 @@ class PG1Date(sqltypes.Date): return psycopg.DateFromMx(value) else: return None + def convert_result_value(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime return value + def get_col_spec(self): return "DATE" + class PG2Time(sqltypes.Time): def get_col_spec(self): return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + class PG1Time(sqltypes.Time): def convert_bind_param(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime @@ -111,32 +125,38 @@ class PG1Time(sqltypes.Time): return psycopg.TimeFromMx(value) else: return None + def convert_result_value(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime return value + def get_col_spec(self): return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + class PGInterval(sqltypes.TypeEngine): def get_col_spec(self): return "INTERVAL" - + class PGText(sqltypes.TEXT): def get_col_spec(self): return "TEXT" + class PGString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} + class PGChar(sqltypes.CHAR): def get_col_spec(self): return "CHAR(%(length)s)" % {'length' : self.length} + class PGBinary(sqltypes.Binary): def get_col_spec(self): return "BYTEA" + class PGBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" - pg2_colspecs = { sqltypes.Integer : PGInteger, sqltypes.Smallinteger : PGSmallInteger, @@ -214,7 +234,7 @@ class PGExecutionContext(default.DefaultExecutionContext): cursor = proxy(str(c), c.get_params()) row = cursor.fetchone() self._last_inserted_ids = [v for v in row] - + class PGDialect(ansisql.ANSIDialect): def __init__(self, module=None, use_oids=False, use_information_schema=False, server_side_cursors=False, **params): self.use_oids = use_oids @@ -225,7 +245,7 @@ class PGDialect(ansisql.ANSIDialect): self.module = psycopg else: self.module = module - # figure psycopg version 1 or 2 + # figure psycopg version 1 or 2 try: if self.module.__version__.startswith('2'): self.version = 2 @@ -238,7 +258,7 @@ class PGDialect(ansisql.ANSIDialect): # produce consistent paramstyle even if psycopg2 module not present if self.module is None: self.paramstyle = 'pyformat' - + def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) if opts.has_key('port'): @@ -265,23 +285,27 @@ class PGDialect(ansisql.ANSIDialect): return sqltypes.adapt_type(typeobj, pg2_colspecs) else: return sqltypes.adapt_type(typeobj, pg1_colspecs) - + def compiler(self, statement, bindparams, **kwargs): return PGCompiler(self, statement, bindparams, **kwargs) + def schemagenerator(self, *args, **kwargs): return PGSchemaGenerator(*args, **kwargs) + def schemadropper(self, *args, **kwargs): return PGSchemaDropper(*args, **kwargs) + def defaultrunner(self, engine, proxy): return PGDefaultRunner(engine, proxy) + def preparer(self): return PGIdentifierPreparer(self) - + def get_default_schema_name(self, connection): if not hasattr(self, '_default_schema_name'): self._default_schema_name = connection.scalar("select current_schema()", None) return self._default_schema_name - + def last_inserted_ids(self): if self.context.last_inserted_ids is None: raise exceptions.InvalidRequestError("no INSERT executed, or cant use cursor.lastrowid without Postgres OIDs enabled") @@ -295,8 +319,12 @@ class PGDialect(ansisql.ANSIDialect): return None def do_executemany(self, c, statement, parameters, context=None): - """we need accurate rowcounts for updates, inserts and deletes. psycopg2 is not nice enough - to produce this correctly for an executemany, so we do our own executemany here.""" + """We need accurate rowcounts for updates, inserts and deletes. + + ``psycopg2`` is not nice enough to produce this correctly for + an executemany, so we do our own executemany here. + """ + rowcount = 0 for param in parameters: c.execute(statement, param) @@ -318,7 +346,7 @@ class PGDialect(ansisql.ANSIDialect): def has_sequence(self, connection, sequence_name): cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name}) return bool(not not cursor.rowcount) - + def reflecttable(self, connection, table): if self.version == 2: ischema_names = pg2_ischema_names @@ -333,10 +361,10 @@ class PGDialect(ansisql.ANSIDialect): schema_where_clause = "n.nspname = :schema" else: schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" - + ## information schema in pg suffers from too many permissions' restrictions ## let us find out at the pg way what is needed... - + SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), @@ -354,25 +382,25 @@ class PGDialect(ansisql.ANSIDialect): ) AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum """ % schema_where_clause - + s = text(SQL_COLS) - c = connection.execute(s, table_name=table.name, + c = connection.execute(s, table_name=table.name, schema=table.schema) rows = c.fetchall() - - if not rows: + + if not rows: raise exceptions.NoSuchTableError(table.name) - + for name, format_type, default, notnull, attnum, table_oid in rows: - ## strip (30) from character varying(30) + ## strip (30) from character varying(30) attype = re.search('([^\(]+)', format_type).group(1) nullable = not notnull - + try: charlen = re.search('\(([\d,]+)\)', format_type).group(1) except: charlen = False - + numericprec = False numericscale = False if attype == 'numeric': @@ -400,7 +428,7 @@ class PGDialect(ansisql.ANSIDialect): kwargs['timezone'] = True elif attype == 'timestamp without time zone': kwargs['timezone'] = False - + coltype = ischema_names[attype] coltype = coltype(*args, **kwargs) colargs= [] @@ -413,31 +441,31 @@ class PGDialect(ansisql.ANSIDialect): default = match.group(1) + sch + '.' + match.group(2) + match.group(3) colargs.append(PassiveDefault(sql.text(default))) table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) - - + + # Primary keys PK_SQL = """ - SELECT attname FROM pg_attribute + SELECT attname FROM pg_attribute WHERE attrelid = ( SELECT indexrelid FROM pg_index i WHERE i.indrelid = :table AND i.indisprimary = 't') ORDER BY attnum - """ + """ t = text(PK_SQL) c = connection.execute(t, table=table_oid) - for row in c.fetchall(): + for row in c.fetchall(): pk = row[0] table.primary_key.add(table.c[pk]) - + # Foreign keys FK_SQL = """ - SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef - FROM pg_catalog.pg_constraint r - WHERE r.conrelid = :table AND r.contype = 'f' + SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef + FROM pg_catalog.pg_constraint r + WHERE r.conrelid = :table AND r.contype = 'f' ORDER BY 1 """ - + t = text(FK_SQL) c = connection.execute(t, table=table_oid) for conname, condef in c.fetchall(): @@ -448,10 +476,10 @@ class PGDialect(ansisql.ANSIDialect): referred_schema = preparer._unquote_identifier(referred_schema) referred_table = preparer._unquote_identifier(referred_table) referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)] - + refspec = [] if referred_schema is not None: - schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, + schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, autoload_with=connection) for column in referred_columns: refspec.append(".".join([referred_schema, referred_table, column])) @@ -459,11 +487,10 @@ class PGDialect(ansisql.ANSIDialect): schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) for column in referred_columns: refspec.append(".".join([referred_table, column])) - + table.append_constraint(ForeignKeyConstraint(constrained_columns, refspec, conname)) class PGCompiler(ansisql.ANSICompiler): - def visit_insert_column(self, column, parameters): # all column primary key inserts must be explicitly present if column.primary_key: @@ -502,10 +529,9 @@ class PGCompiler(ansisql.ANSICompiler): if isinstance(binary.type, sqltypes.String) and binary.operator == '+': return '||' else: - return ansisql.ANSICompiler.binary_operator_string(self, binary) - + return ansisql.ANSICompiler.binary_operator_string(self, binary) + class PGSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): @@ -527,7 +553,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() - + class PGSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)): @@ -543,7 +569,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): return c.fetchone()[0] elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema - # TODO: this has to build into the Sequence object so we can get the quoting + # TODO: this has to build into the Sequence object so we can get the quoting # logic from it if sch is not None: exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) @@ -555,7 +581,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): return ansisql.ANSIDefaultRunner.get_column_default(self, column) else: return ansisql.ANSIDefaultRunner.get_column_default(self, column) - + def visit_sequence(self, seq): if not seq.optional: c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq)) @@ -566,9 +592,10 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def _fold_identifier_case(self, value): return value.lower() + def _unquote_identifier(self, value): if value[0] == self.initial_quote: value = value[1:-1].replace('""','"') return value - + dialect = PGDialect diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 2ab3c0d5a..b29be9eed 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -31,18 +31,22 @@ class SLNumeric(sqltypes.Numeric): return "NUMERIC" else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} + class SLInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" + class SLSmallInteger(sqltypes.Smallinteger): def get_col_spec(self): return "SMALLINT" + class DateTimeMixin(object): def convert_bind_param(self, value, dialect): if value is not None: return str(value) else: return None + def _cvt(self, value, dialect, fmt): if value is None: return None @@ -52,49 +56,61 @@ class DateTimeMixin(object): except ValueError: (value, microsecond) = (value, 0) return time.strptime(value, fmt)[0:6] + (microsecond,) - + class SLDateTime(DateTimeMixin,sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP" + def convert_result_value(self, value, dialect): tup = self._cvt(value, dialect, "%Y-%m-%d %H:%M:%S") return tup and datetime.datetime(*tup) + class SLDate(DateTimeMixin, sqltypes.Date): def get_col_spec(self): return "DATE" + def convert_result_value(self, value, dialect): tup = self._cvt(value, dialect, "%Y-%m-%d") return tup and datetime.date(*tup[0:3]) + class SLTime(DateTimeMixin, sqltypes.Time): def get_col_spec(self): return "TIME" + def convert_result_value(self, value, dialect): tup = self._cvt(value, dialect, "%H:%M:%S") return tup and datetime.time(*tup[3:7]) + class SLText(sqltypes.TEXT): def get_col_spec(self): return "TEXT" + class SLString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} + class SLChar(sqltypes.CHAR): def get_col_spec(self): return "CHAR(%(length)s)" % {'length' : self.length} + class SLBinary(sqltypes.Binary): def get_col_spec(self): return "BLOB" + class SLBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" + def convert_bind_param(self, value, dialect): if value is None: return None return value and 1 or 0 + def convert_result_value(self, value, dialect): if value is None: return None return value and True or False - + colspecs = { sqltypes.Integer : SLInteger, sqltypes.Smallinteger : SLSmallInteger, @@ -135,49 +151,56 @@ def descriptor(): ('database', "Database Filename",None) ]} - class SQLiteExecutionContext(default.DefaultExecutionContext): def post_exec(self, engine, proxy, compiled, parameters, **kwargs): if getattr(compiled, "isinsert", False): self._last_inserted_ids = [proxy().lastrowid] - + class SQLiteDialect(ansisql.ANSIDialect): def __init__(self, **kwargs): def vers(num): return tuple([int(x) for x in num.split('.')]) self.supports_cast = (sqlite is not None and vers(sqlite.sqlite_version) >= vers("3.2.3")) ansisql.ANSIDialect.__init__(self, **kwargs) + def compiler(self, statement, bindparams, **kwargs): return SQLiteCompiler(self, statement, bindparams, **kwargs) + def schemagenerator(self, *args, **kwargs): return SQLiteSchemaGenerator(*args, **kwargs) + def schemadropper(self, *args, **kwargs): return SQLiteSchemaDropper(*args, **kwargs) + def preparer(self): return SQLiteIdentifierPreparer(self) + def create_connect_args(self, url): filename = url.database or ':memory:' return ([filename], url.query) + def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) + def create_execution_context(self): return SQLiteExecutionContext(self) + def last_inserted_ids(self): return self.context.last_inserted_ids - + def oid_column_name(self, column): return "oid" def dbapi(self): return sqlite - + def has_table(self, connection, table_name, schema=None): cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {}) row = cursor.fetchone() - + # consume remaining rows, to work around: http://www.sqlite.org/cvstrac/tktview?tn=1884 while cursor.fetchone() is not None:pass - + return (row is not None) def reflecttable(self, connection, table): @@ -198,7 +221,7 @@ class SQLiteDialect(ansisql.ANSIDialect): else: coltype = "VARCHAR" args = '' - + #print "coltype: " + repr(coltype) + " args: " + repr(args) coltype = pragma_names.get(coltype, SLString) if args is not None: @@ -210,10 +233,10 @@ class SQLiteDialect(ansisql.ANSIDialect): if has_default: colargs.append(PassiveDefault('?')) table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs)) - + if not found_table: raise exceptions.NoSuchTableError(table.name) - + c = connection.execute("PRAGMA foreign_key_list(" + table.name + ")", {}) fks = {} while True: @@ -229,7 +252,7 @@ class SQLiteDialect(ansisql.ANSIDialect): except KeyError: fk = ([],[]) fks[constraint_name] = fk - + #print "row! " + repr([key for key in row.keys()]), repr(row) # look up the table based on the given table's engine, not 'self', # since it could be a ProxyEngine @@ -241,7 +264,7 @@ class SQLiteDialect(ansisql.ANSIDialect): if refspec not in fk[1]: fk[1].append(refspec) for name, value in fks.iteritems(): - table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1])) + table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1])) # check for UNIQUE indexes c = connection.execute("PRAGMA index_list(" + table.name + ")", {}) unique_indexes = [] @@ -264,7 +287,7 @@ class SQLiteDialect(ansisql.ANSIDialect): # unique index that includes the pk is considered a multiple primary key for col in cols: table.primary_key.add(table.columns[col]) - + class SQLiteCompiler(ansisql.ANSICompiler): def visit_cast(self, cast): if self.dialect.supports_cast: @@ -274,6 +297,7 @@ class SQLiteCompiler(ansisql.ANSICompiler): # not sure if we want to set the typemap here... self.typemap.setdefault("CAST", cast.type) self.strings[cast] = self.strings[cast.clause] + def limit_clause(self, select): text = "" if select.limit is not None: @@ -285,6 +309,7 @@ class SQLiteCompiler(ansisql.ANSICompiler): else: text += " OFFSET 0" return text + def for_update_clause(self, select): # sqlite has no "FOR UPDATE" AFAICT return '' @@ -298,7 +323,7 @@ class SQLiteCompiler(ansisql.ANSICompiler): class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): def supports_alter(self): return False - + def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) @@ -328,4 +353,4 @@ class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer): super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True) dialect = SQLiteDialect -poolclass = pool.SingletonThreadPool +poolclass = pool.SingletonThreadPool |
