diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-04-02 21:36:11 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-04-02 21:36:11 +0000 |
| commit | cdceb3c3714af707bfe3ede10af6536eaf529ca8 (patch) | |
| tree | 2ccbfb60cd10d995c0309801b0adc4fc3a1f0a44 /lib/sqlalchemy/databases | |
| parent | 8607de3159fd37923ae99118c499935c4a54d0e2 (diff) | |
| download | sqlalchemy-cdceb3c3714af707bfe3ede10af6536eaf529ca8.tar.gz | |
- merged the "execcontext" branch, refactors engine/dialect codepaths
- much more functionality moved into ExecutionContext, which impacted
the API used by dialects to some degree
- ResultProxy and subclasses now designed sanely
- merged patch for #522, Unicode subclasses String directly,
MSNVarchar implements for MS-SQL, removed MSUnicode.
- String moves its "VARCHAR"/"TEXT" switchy thing into
"get_search_list()" function, which VARCHAR and CHAR can override
to not return TEXT in any case (didn't do the latter yet)
- implements server side cursors for postgres, unit tests, #514
- includes overhaul of dbapi import strategy #480, all dbapi
importing happens in dialect method "dbapi()", is only called
inside of create_engine() for default and threadlocal strategies.
Dialect subclasses have a datamember "dbapi" referencing the loaded
module which may be None.
- added "mock" engine strategy, doesn't require a DBAPI module and
gives you a "Connection" which just sends all executes to a callable.
can be used to create string output of create_all()/drop_all().
Diffstat (limited to 'lib/sqlalchemy/databases')
| -rw-r--r-- | lib/sqlalchemy/databases/firebird.py | 73 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 159 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 50 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 75 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 138 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/sqlite.py | 60 |
6 files changed, 238 insertions, 317 deletions
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 91a0869c6..2ab88101a 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -15,12 +15,9 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions -try: +def dbapi(): import kinterbasdb -except: - kinterbasdb = None - -dbmodule = kinterbasdb + return kinterbasdb _initialized_kb = False @@ -33,7 +30,6 @@ class FBNumeric(sqltypes.Numeric): return "NUMERIC(%(precision)s, %(length)s)" % { 'precision': self.precision, 'length' : self.length } - class FBInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" @@ -111,24 +107,11 @@ class FBExecutionContext(default.DefaultExecutionContext): class FBDialect(ansisql.ANSIDialect): - def __init__(self, module = None, **params): - global _initialized_kb - self.module = module or dbmodule - self.opts = {} - - if not _initialized_kb: - _initialized_kb = True - type_conv = params.get('type_conv', 200) or 200 - if isinstance(type_conv, types.StringTypes): - type_conv = int(type_conv) - - concurrency_level = params.get('concurrency_level', 1) or 1 - if isinstance(concurrency_level, types.StringTypes): - concurrency_level = int(concurrency_level) + def __init__(self, type_conv=200, concurrency_level=1, **kwargs): + ansisql.ANSIDialect.__init__(self, **kwargs) - if kinterbasdb is not None: - kinterbasdb.init(type_conv=type_conv, concurrency_level=concurrency_level) - ansisql.ANSIDialect.__init__(self, **params) + self.type_conv = type_conv + self.concurrency_level= concurrency_level def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) @@ -136,15 +119,17 @@ class FBDialect(ansisql.ANSIDialect): opts['host'] = "%s/%s" % (opts['host'], opts['port']) del opts['port'] opts.update(url.query) - # pop arguments that we took at the module level - opts.pop('type_conv', None) 
- opts.pop('concurrency_level', None) - self.opts = opts - return ([], self.opts) + type_conv = opts.pop('type_conv', self.type_conv) + concurrency_level = opts.pop('concurrency_level', self.concurrency_level) + global _initialized_kb + if not _initialized_kb and self.dbapi is not None: + _initialized_kb = True + self.dbapi.init(type_conv=type_conv, concurrency_level=concurrency_level) + return ([], opts) - def create_execution_context(self): - return FBExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return FBExecutionContext(self, *args, **kwargs) def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) @@ -156,13 +141,13 @@ class FBDialect(ansisql.ANSIDialect): return FBCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return FBSchemaGenerator(*args, **kwargs) + return FBSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return FBSchemaDropper(*args, **kwargs) + return FBSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, engine, proxy): - return FBDefaultRunner(engine, proxy) + def defaultrunner(self, connection): + return FBDefaultRunner(connection) def preparer(self): return FBIdentifierPreparer(self) @@ -292,9 +277,6 @@ class FBDialect(ansisql.ANSIDialect): for name,value in fks.iteritems(): table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name)) - def last_inserted_ids(self): - return self.context.last_inserted_ids - def do_execute(self, cursor, statement, parameters, **kwargs): cursor.execute(statement, parameters or []) @@ -304,15 +286,6 @@ class FBDialect(ansisql.ANSIDialect): def do_commit(self, connection): connection.commit(True) - def connection(self): - """Returns a managed DBAPI connection from this SQLEngine's connection pool.""" - c = self._pool.connect() - c.supportsTransactions = 0 - return c - - def dbapi(self): - return self.module - class FBCompiler(ansisql.ANSICompiler): 
"""Firebird specific idiosincrasies""" @@ -364,7 +337,7 @@ class FBCompiler(ansisql.ANSICompiler): class FBSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) - colspec += " " + column.type.engine_impl(self.engine).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: @@ -388,11 +361,11 @@ class FBSchemaDropper(ansisql.ANSISchemaDropper): class FBDefaultRunner(ansisql.ANSIDefaultRunner): def exec_default_sql(self, default): - c = sql.select([default.arg], from_obj=["rdb$database"], engine=self.engine).compile() - return self.proxy(str(c), c.get_params()).fetchone()[0] + c = sql.select([default.arg], from_obj=["rdb$database"]).compile(engine=self.engine) + return self.connection.execute_compiled(c).scalar() def visit_sequence(self, seq): - return self.proxy("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").fetchone()[0] + return self.connection.execute_text("SELECT gen_id(" + seq.name + ", 1) FROM rdb$database").scalar() RESERVED_WORDS = util.Set( diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 1852edefb..6d2ff66cd 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -52,7 +52,22 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions - +def dbapi(module_name=None): + if module_name: + try: + dialect_cls = dialect_mapping[module_name] + return dialect_cls.import_dbapi() + except KeyError: + raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name) + else: + for dialect_cls in [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]: + try: + return dialect_cls.import_dbapi() + except ImportError, e: + pass + else: + raise ImportError('No DBAPI 
module detected for MSSQL - please install adodbapi, pymssql or pyodbc') + class MSNumeric(sqltypes.Numeric): def convert_result_value(self, value, dialect): return value @@ -142,9 +157,6 @@ class MSString(sqltypes.String): return "VARCHAR(%(length)s)" % {'length' : self.length} class MSNVarchar(MSString): - """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True. """ - impl = sqltypes.Unicode - def get_col_spec(self): if self.length: return "NVARCHAR(%(length)s)" % {'length' : self.length} @@ -154,19 +166,7 @@ class MSNVarchar(MSString): return "NTEXT" class AdoMSNVarchar(MSNVarchar): - def convert_bind_param(self, value, dialect): - return value - - def convert_result_value(self, value, dialect): - return value - -class MSUnicode(sqltypes.Unicode): - """Unicode subclass, does Unicode conversion in all cases, uses NVARCHAR impl.""" - impl = MSNVarchar - -class AdoMSUnicode(MSUnicode): - impl = AdoMSNVarchar - + """overrides bindparam/result processing to not convert any unicode strings""" def convert_bind_param(self, value, dialect): return value @@ -215,9 +215,9 @@ def descriptor(): ]} class MSSQLExecutionContext(default.DefaultExecutionContext): - def __init__(self, dialect): + def __init__(self, *args, **kwargs): self.IINSERT = self.HASIDENT = False - super(MSSQLExecutionContext, self).__init__(dialect) + super(MSSQLExecutionContext, self).__init__(*args, **kwargs) def _has_implicit_sequence(self, column): if column.primary_key and column.autoincrement: @@ -227,14 +227,14 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): return True return False - def pre_exec(self, engine, proxy, compiled, parameters, **kwargs): + def pre_exec(self): """MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns. Activate it if the feature is turned on and needed. 
""" - if getattr(compiled, "isinsert", False): - tbl = compiled.statement.table + if self.compiled.isinsert: + tbl = self.compiled.statement.table if not hasattr(tbl, 'has_sequence'): tbl.has_sequence = None for column in tbl.c: @@ -243,39 +243,43 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): break self.HASIDENT = bool(tbl.has_sequence) - if engine.dialect.auto_identity_insert and self.HASIDENT: - if isinstance(parameters, list): - self.IINSERT = tbl.has_sequence.key in parameters[0] + if self.dialect.auto_identity_insert and self.HASIDENT: + if isinstance(self.compiled_parameters, list): + self.IINSERT = tbl.has_sequence.key in self.compiled_parameters[0] else: - self.IINSERT = tbl.has_sequence.key in parameters + self.IINSERT = tbl.has_sequence.key in self.compiled_parameters else: self.IINSERT = False if self.IINSERT: - proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name) + # TODO: quoting rules for table name here ? + self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.name) - super(MSSQLExecutionContext, self).pre_exec(engine, proxy, compiled, parameters, **kwargs) + super(MSSQLExecutionContext, self).pre_exec() - def post_exec(self, engine, proxy, compiled, parameters, **kwargs): + def post_exec(self): """Turn off the INDENTITY_INSERT mode if it's been activated, and fetch recently inserted IDENTIFY values (works only for one column). """ - if getattr(compiled, "isinsert", False): + if self.compiled.isinsert: if self.IINSERT: - proxy("SET IDENTITY_INSERT %s OFF" % compiled.statement.table.name) + # TODO: quoting rules for table name here ? 
+ self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.name) self.IINSERT = False elif self.HASIDENT: - cursor = proxy("SELECT @@IDENTITY AS lastrowid") - row = cursor.fetchone() + self.cursor.execute("SELECT @@IDENTITY AS lastrowid") + row = self.cursor.fetchone() self._last_inserted_ids = [int(row[0])] # print "LAST ROW ID", self._last_inserted_ids self.HASIDENT = False + super(MSSQLExecutionContext, self).post_exec() class MSSQLDialect(ansisql.ANSIDialect): colspecs = { + sqltypes.Unicode : MSNVarchar, sqltypes.Integer : MSInteger, sqltypes.Smallinteger: MSSmallInteger, sqltypes.Numeric : MSNumeric, @@ -283,7 +287,6 @@ class MSSQLDialect(ansisql.ANSIDialect): sqltypes.DateTime : MSDateTime, sqltypes.Date : MSDate, sqltypes.String : MSString, - sqltypes.Unicode : MSUnicode, sqltypes.Binary : MSBinary, sqltypes.Boolean : MSBoolean, sqltypes.TEXT : MSText, @@ -296,7 +299,7 @@ class MSSQLDialect(ansisql.ANSIDialect): 'smallint' : MSSmallInteger, 'tinyint' : MSTinyInteger, 'varchar' : MSString, - 'nvarchar' : MSUnicode, + 'nvarchar' : MSNVarchar, 'char' : MSChar, 'nchar' : MSNChar, 'text' : MSText, @@ -312,30 +315,16 @@ class MSSQLDialect(ansisql.ANSIDialect): 'image' : MSBinary } - def __new__(cls, module_name=None, *args, **kwargs): - module = kwargs.get('module', None) + def __new__(cls, dbapi=None, *args, **kwargs): if cls != MSSQLDialect: return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs) - if module_name: - dialect = dialect_mapping.get(module_name) - if not dialect: - raise exceptions.InvalidRequestError('Unsupported MSSQL module requested (must be adodbpi, pymssql or pyodbc): ' + module_name) - if not hasattr(dialect, 'module'): - raise dialect.saved_import_error + if dbapi: + dialect = dialect_mapping.get(dbapi.__name__) return dialect(*args, **kwargs) - elif module: - return object.__new__(cls, *args, **kwargs) else: - for dialect in dialect_preference: - if hasattr(dialect, 'module'): - return dialect(*args, **kwargs) 
- #raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc') - else: - return object.__new__(cls, *args, **kwargs) + return object.__new__(cls, *args, **kwargs) - def __init__(self, module_name=None, module=None, auto_identity_insert=True, **params): - if not hasattr(self, 'module'): - self.module = module + def __init__(self, auto_identity_insert=True, **params): super(MSSQLDialect, self).__init__(**params) self.auto_identity_insert = auto_identity_insert self.text_as_varchar = False @@ -352,8 +341,8 @@ class MSSQLDialect(ansisql.ANSIDialect): self.text_as_varchar = bool(opts.pop('text_as_varchar')) return self.make_connect_string(opts) - def create_execution_context(self): - return MSSQLExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return MSSQLExecutionContext(self, *args, **kwargs) def type_descriptor(self, typeobj): newobj = sqltypes.adapt_type(typeobj, self.colspecs) @@ -373,13 +362,13 @@ class MSSQLDialect(ansisql.ANSIDialect): return MSSQLCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return MSSQLSchemaGenerator(*args, **kwargs) + return MSSQLSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return MSSQLSchemaDropper(*args, **kwargs) + return MSSQLSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, engine, proxy): - return MSSQLDefaultRunner(engine, proxy) + def defaultrunner(self, connection, **kwargs): + return MSSQLDefaultRunner(connection, **kwargs) def preparer(self): return MSSQLIdentifierPreparer(self) @@ -411,19 +400,12 @@ class MSSQLDialect(ansisql.ANSIDialect): def raw_connection(self, connection): """Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes""" try: + # TODO: probably want to move this to individual dialect subclasses to + # save on the exception throw + simplify return connection.connection.__dict__['_pymssqlCnx__cnx'] 
except: return connection.connection.adoConn - def connection(self): - """returns a managed DBAPI connection from this SQLEngine's connection pool.""" - c = self._pool.connect() - c.supportsTransactions = 0 - return c - - def dbapi(self): - return self.module - def uppercase_table(self, t): # convert all names to uppercase -- fixes refs to INFORMATION_SCHEMA for case-senstive DBs, and won't matter for case-insensitive t.name = t.name.upper() @@ -558,13 +540,14 @@ class MSSQLDialect(ansisql.ANSIDialect): table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) class MSSQLDialect_pymssql(MSSQLDialect): - try: + def import_dbapi(cls): import pymssql as module # pymmsql doesn't have a Binary method. we use string + # TODO: monkeypatching here is less than ideal module.Binary = lambda st: str(st) - except ImportError, e: - saved_import_error = e - + return module + import_dbapi = classmethod(import_dbapi) + def supports_sane_rowcount(self): return True @@ -578,7 +561,7 @@ class MSSQLDialect_pymssql(MSSQLDialect): def create_connect_args(self, url): r = super(MSSQLDialect_pymssql, self).create_connect_args(url) if hasattr(self, 'query_timeout'): - self.module._mssql.set_query_timeout(self.query_timeout) + self.dbapi._mssql.set_query_timeout(self.query_timeout) return r def make_connect_string(self, keys): @@ -621,15 +604,16 @@ class MSSQLDialect_pymssql(MSSQLDialect): ## r.fetch_array() class MSSQLDialect_pyodbc(MSSQLDialect): - try: + + def import_dbapi(cls): import pyodbc as module - except ImportError, e: - saved_import_error = e - + return module + import_dbapi = classmethod(import_dbapi) + colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.Unicode] = AdoMSUnicode + colspecs[sqltypes.Unicode] = AdoMSNVarchar ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['nvarchar'] = AdoMSUnicode + ischema_names['nvarchar'] = AdoMSNVarchar def supports_sane_rowcount(self): return False @@ -648,15 +632,15 @@ 
class MSSQLDialect_pyodbc(MSSQLDialect): class MSSQLDialect_adodbapi(MSSQLDialect): - try: + def import_dbapi(cls): import adodbapi as module - except ImportError, e: - saved_import_error = e + return module + import_dbapi = classmethod(import_dbapi) colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.Unicode] = AdoMSUnicode + colspecs[sqltypes.Unicode] = AdoMSNVarchar ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['nvarchar'] = AdoMSUnicode + ischema_names['nvarchar'] = AdoMSNVarchar def supports_sane_rowcount(self): return True @@ -676,13 +660,11 @@ class MSSQLDialect_adodbapi(MSSQLDialect): connectors.append("Integrated Security=SSPI") return [[";".join (connectors)], {}] - dialect_mapping = { 'pymssql': MSSQLDialect_pymssql, 'pyodbc': MSSQLDialect_pyodbc, 'adodbapi': MSSQLDialect_adodbapi } -dialect_preference = [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc] class MSSQLCompiler(ansisql.ANSICompiler): @@ -770,7 +752,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec() + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() # install a IDENTITY Sequence if we have an implicit IDENTITY column if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ @@ -797,6 +779,7 @@ class MSSQLSchemaDropper(ansisql.ANSISchemaDropper): self.execute() class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner): + # TODO: does ms-sql have standalone sequences ? 
pass class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 5fc63234a..65ccb6af1 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -12,12 +12,9 @@ import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions from array import array -try: +def dbapi(): import MySQLdb as mysql - import MySQLdb.constants.CLIENT as CLIENT_FLAGS -except: - mysql = None - CLIENT_FLAGS = None + return mysql def kw_colspec(self, spec): if self.unsigned: @@ -158,8 +155,6 @@ class MSLongText(MSText): return "LONGTEXT" class MSString(sqltypes.String): - def __init__(self, length=None, *extra, **kwargs): - sqltypes.String.__init__(self, length=length) def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} @@ -277,16 +272,12 @@ def descriptor(): ]} class MySQLExecutionContext(default.DefaultExecutionContext): - def post_exec(self, engine, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False): - self._last_inserted_ids = [proxy().lastrowid] + def post_exec(self): + if self.compiled.isinsert: + self._last_inserted_ids = [self.cursor.lastrowid] class MySQLDialect(ansisql.ANSIDialect): - def __init__(self, module = None, **kwargs): - if module is None: - self.module = mysql - else: - self.module = module + def __init__(self, **kwargs): ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs) def create_connect_args(self, url): @@ -305,14 +296,18 @@ class MySQLDialect(ansisql.ANSIDialect): # TODO: what about options like "ssl", "cursorclass" and "conv" ? 
client_flag = opts.get('client_flag', 0) - if CLIENT_FLAGS is not None: - client_flag |= CLIENT_FLAGS.FOUND_ROWS + if self.dbapi is not None: + try: + import MySQLdb.constants.CLIENT as CLIENT_FLAGS + client_flag |= CLIENT_FLAGS.FOUND_ROWS + except: + pass opts['client_flag'] = client_flag return [[], opts] - def create_execution_context(self): - return MySQLExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return MySQLExecutionContext(self, *args, **kwargs) def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) @@ -324,10 +319,10 @@ class MySQLDialect(ansisql.ANSIDialect): return MySQLCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return MySQLSchemaGenerator(*args, **kwargs) + return MySQLSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return MySQLSchemaDropper(*args, **kwargs) + return MySQLSchemaDropper(self, *args, **kwargs) def preparer(self): return MySQLIdentifierPreparer(self) @@ -337,14 +332,14 @@ class MySQLDialect(ansisql.ANSIDialect): rowcount = cursor.executemany(statement, parameters) if context is not None: context._rowcount = rowcount - except mysql.OperationalError, o: + except self.dbapi.OperationalError, o: if o.args[0] == 2006 or o.args[0] == 2014: cursor.invalidate() raise o def do_execute(self, cursor, statement, parameters, **kwargs): try: cursor.execute(statement, parameters) - except mysql.OperationalError, o: + except self.dbapi.OperationalError, o: if o.args[0] == 2006 or o.args[0] == 2014: cursor.invalidate() raise o @@ -361,11 +356,9 @@ class MySQLDialect(ansisql.ANSIDialect): self._default_schema_name = text("select database()", self).scalar() return self._default_schema_name - def dbapi(self): - return self.module - def has_table(self, connection, table_name, schema=None): - cursor = connection.execute("show table status like '" + table_name + "'") + cursor = connection.execute("show table status 
like %s", [table_name]) + print "CURSOR", cursor, "ROWCOUNT", cursor.rowcount, "REAL RC", cursor.cursor.rowcount return bool( not not cursor.rowcount ) def reflecttable(self, connection, table): @@ -492,8 +485,7 @@ class MySQLCompiler(ansisql.ANSICompiler): class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): - t = column.type.engine_impl(self.engine) - colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec() + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index adea127bf..5377759a2 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -8,15 +8,13 @@ import sys, StringIO, string, re from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging -import sqlalchemy.engine.default as default +from sqlalchemy.engine import default, base import sqlalchemy.types as sqltypes -try: +def dbapi(): import cx_Oracle -except: - cx_Oracle = None + return cx_Oracle -ORACLE_BINARY_TYPES = [getattr(cx_Oracle, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(cx_Oracle, k)] class OracleNumeric(sqltypes.Numeric): def get_col_spec(self): @@ -149,26 +147,32 @@ def descriptor(): ]} class OracleExecutionContext(default.DefaultExecutionContext): - def pre_exec(self, engine, proxy, compiled, parameters): - super(OracleExecutionContext, self).pre_exec(engine, proxy, compiled, parameters) + def pre_exec(self): + super(OracleExecutionContext, self).pre_exec() if self.dialect.auto_setinputsizes: - self.set_input_sizes(proxy(), parameters) + self.set_input_sizes() + + def get_result_proxy(self): + if self.cursor.description is 
not None: + for column in self.cursor.description: + type_code = column[1] + if type_code in self.dialect.ORACLE_BINARY_TYPES: + return base.BufferedColumnResultProxy(self) + + return base.ResultProxy(self) class OracleDialect(ansisql.ANSIDialect): - def __init__(self, use_ansi=True, auto_setinputsizes=True, module=None, threaded=True, **kwargs): + def __init__(self, use_ansi=True, auto_setinputsizes=True, threaded=True, **kwargs): + ansisql.ANSIDialect.__init__(self, default_paramstyle='named', **kwargs) self.use_ansi = use_ansi self.threaded = threaded - if module is None: - self.module = cx_Oracle - else: - self.module = module - self.supports_timestamp = hasattr(self.module, 'TIMESTAMP' ) + self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) self.auto_setinputsizes = auto_setinputsizes - ansisql.ANSIDialect.__init__(self, **kwargs) - - def dbapi(self): - return self.module - + if self.dbapi is not None: + self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)] + else: + self.ORACLE_BINARY_TYPES = [] + def create_connect_args(self, url): if url.database: # if we have a database, then we have a remote host @@ -177,7 +181,7 @@ class OracleDialect(ansisql.ANSIDialect): port = int(port) else: port = 1521 - dsn = self.module.makedsn(url.host,port,url.database) + dsn = self.dbapi.makedsn(url.host,port,url.database) else: # we have a local tnsname dsn = url.host @@ -206,20 +210,20 @@ class OracleDialect(ansisql.ANSIDialect): else: return "rowid" - def create_execution_context(self): - return OracleExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return OracleExecutionContext(self, *args, **kwargs) def compiler(self, statement, bindparams, **kwargs): return OracleCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return OracleSchemaGenerator(*args, **kwargs) + return 
OracleSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return OracleSchemaDropper(*args, **kwargs) + return OracleSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, engine, proxy): - return OracleDefaultRunner(engine, proxy) + def defaultrunner(self, connection, **kwargs): + return OracleDefaultRunner(connection, **kwargs) def has_table(self, connection, table_name, schema=None): cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()}) @@ -405,15 +409,6 @@ class OracleDialect(ansisql.ANSIDialect): if context is not None: context._rowcount = rowcount - def create_result_proxy_args(self, connection, cursor): - args = super(OracleDialect, self).create_result_proxy_args(connection, cursor) - if cursor and cursor.description: - for column in cursor.description: - type_code = column[1] - if type_code in ORACLE_BINARY_TYPES: - args['should_prefetch'] = True - break - return args OracleDialect.logger = logging.class_logger(OracleDialect) @@ -569,7 +564,7 @@ class OracleCompiler(ansisql.ANSICompiler): class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) - colspec += " " + column.type.engine_impl(self.engine).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -579,22 +574,22 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec def visit_sequence(self, sequence): - if not self.engine.dialect.has_sequence(self.connection, sequence.name): + if not self.dialect.has_sequence(self.connection, sequence.name): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() class OracleSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): - if 
self.engine.dialect.has_sequence(self.connection, sequence.name): + if self.dialect.has_sequence(self.connection, sequence.name): self.append("DROP SEQUENCE %s" % sequence.name) self.execute() class OracleDefaultRunner(ansisql.ANSIDefaultRunner): def exec_default_sql(self, default): c = sql.select([default.arg], from_obj=["DUAL"], engine=self.engine).compile() - return self.proxy(str(c), c.get_params()).fetchone()[0] + return self.connection.execute_compiled(c).scalar() def visit_sequence(self, seq): - return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0] + return self.connection.execute_text("SELECT " + seq.name + ".nextval FROM DUAL").scalar() dialect = OracleDialect diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index d83607793..2943d163e 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -4,33 +4,28 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import datetime, sys, StringIO, string, types, re - -import sqlalchemy.util as util -import sqlalchemy.sql as sql -import sqlalchemy.engine as engine -import sqlalchemy.engine.default as default -import sqlalchemy.schema as schema -import sqlalchemy.ansisql as ansisql +import datetime, string, types, re, random + +from sqlalchemy import util, sql, schema, ansisql, exceptions +from sqlalchemy.engine import base, default import sqlalchemy.types as sqltypes -import sqlalchemy.exceptions as exceptions from sqlalchemy.databases import information_schema as ischema -import re try: import mx.DateTime.DateTime as mxDateTime except: mxDateTime = None -try: - import psycopg2 as psycopg - #import psycopg2.psycopg1 as psycopg -except: +def dbapi(): try: - import psycopg - except: - psycopg = None - + import psycopg2 as psycopg + except ImportError, e: + try: + import psycopg + except ImportError, e2: + raise e + return psycopg + class 
PGInet(sqltypes.TypeEngine): def get_col_spec(self): return "INET" @@ -74,8 +69,8 @@ class PG1DateTime(sqltypes.DateTime): mx_datetime = mxDateTime(value.year, value.month, value.day, value.hour, value.minute, seconds) - return psycopg.TimestampFromMx(mx_datetime) - return psycopg.TimestampFromMx(value) + return dialect.dbapi.TimestampFromMx(mx_datetime) + return dialect.dbapi.TimestampFromMx(value) else: return None @@ -101,7 +96,7 @@ class PG1Date(sqltypes.Date): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime # this one doesnt seem to work with the "emulation" mode if value is not None: - return psycopg.DateFromMx(value) + return dialect.dbapi.DateFromMx(value) else: return None @@ -219,44 +214,49 @@ def descriptor(): ]} class PGExecutionContext(default.DefaultExecutionContext): - def post_exec(self, engine, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None: - if not engine.dialect.use_oids: + + def is_select(self): + return re.match(r'SELECT', self.statement.lstrip(), re.I) and not re.search(r'FOR UPDATE\s*$', self.statement, re.I) + + def create_cursor(self): + if self.dialect.server_side_cursors and self.is_select(): + # use server-side cursors: + # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html + ident = "c" + hex(random.randint(0, 65535))[2:] + return self.connection.connection.cursor(ident) + else: + return self.connection.connection.cursor() + + def get_result_proxy(self): + if self.dialect.server_side_cursors and self.is_select(): + return base.BufferedRowResultProxy(self) + else: + return base.ResultProxy(self) + + def post_exec(self): + if self.compiled.isinsert and self.last_inserted_ids is None: + if not self.dialect.use_oids: pass # will raise invalid error when they go to get them else: - table = compiled.statement.table - cursor = proxy() - if cursor.lastrowid is not None and table is not None and len(table.primary_key): - s = 
sql.select(table.primary_key, table.oid_column == cursor.lastrowid) - c = s.compile(engine=engine) - cursor = proxy(str(c), c.get_params()) - row = cursor.fetchone() + table = self.compiled.statement.table + if self.cursor.lastrowid is not None and table is not None and len(table.primary_key): + s = sql.select(table.primary_key, table.oid_column == self.cursor.lastrowid) + row = self.connection.execute(s).fetchone() self._last_inserted_ids = [v for v in row] - + super(PGExecutionContext, self).post_exec() + class PGDialect(ansisql.ANSIDialect): - def __init__(self, module=None, use_oids=False, use_information_schema=False, server_side_cursors=False, **params): + def __init__(self, use_oids=False, use_information_schema=False, server_side_cursors=False, **kwargs): + ansisql.ANSIDialect.__init__(self, default_paramstyle='pyformat', **kwargs) self.use_oids = use_oids self.server_side_cursors = server_side_cursors - if module is None: - #if psycopg is None: - # raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument") - self.module = psycopg + if self.dbapi is None or not hasattr(self.dbapi, '__version__') or self.dbapi.__version__.startswith('2'): + self.version = 2 else: - self.module = module - # figure psycopg version 1 or 2 - try: - if self.module.__version__.startswith('2'): - self.version = 2 - else: - self.version = 1 - except: self.version = 1 - ansisql.ANSIDialect.__init__(self, **params) self.use_information_schema = use_information_schema - # produce consistent paramstyle even if psycopg2 module not present - if self.module is None: - self.paramstyle = 'pyformat' + self.paramstyle = 'pyformat' def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) @@ -268,16 +268,9 @@ class PGDialect(ansisql.ANSIDialect): opts.update(url.query) return ([], opts) - def create_cursor(self, connection): - if self.server_side_cursors: - # use server-side cursors: - # 
http://lists.initd.org/pipermail/psycopg/2007-January/005251.html - return connection.cursor('x') - else: - return connection.cursor() - def create_execution_context(self): - return PGExecutionContext(self) + def create_execution_context(self, *args, **kwargs): + return PGExecutionContext(self, *args, **kwargs) def max_identifier_length(self): return 68 @@ -292,13 +285,13 @@ class PGDialect(ansisql.ANSIDialect): return PGCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return PGSchemaGenerator(*args, **kwargs) + return PGSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return PGSchemaDropper(*args, **kwargs) + return PGSchemaDropper(self, *args, **kwargs) - def defaultrunner(self, engine, proxy): - return PGDefaultRunner(engine, proxy) + def defaultrunner(self, connection, **kwargs): + return PGDefaultRunner(connection, **kwargs) def preparer(self): return PGIdentifierPreparer(self) @@ -326,7 +319,6 @@ class PGDialect(ansisql.ANSIDialect): ``psycopg2`` is not nice enough to produce this correctly for an executemany, so we do our own executemany here. """ - rowcount = 0 for param in parameters: c.execute(statement, param) @@ -334,9 +326,6 @@ class PGDialect(ansisql.ANSIDialect): if context is not None: context._rowcount = rowcount - def dbapi(self): - return self.module - def has_table(self, connection, table_name, schema=None): # seems like case gets folded in pg_class... 
if schema is None: @@ -542,7 +531,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): else: colspec += " SERIAL" else: - colspec += " " + column.type.engine_impl(self.engine).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -567,8 +556,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): if column.primary_key: # passive defaults on primary keys have to be overridden if isinstance(column.default, schema.PassiveDefault): - c = self.proxy("select %s" % column.default.arg) - return c.fetchone()[0] + return self.connection.execute_text("select %s" % column.default.arg).scalar() elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema # TODO: this has to build into the Sequence object so we can get the quoting @@ -577,17 +565,13 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner): exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name) else: exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) - c = self.proxy(exc) - return c.fetchone()[0] - else: - return ansisql.ANSIDefaultRunner.get_column_default(self, column) - else: - return ansisql.ANSIDefaultRunner.get_column_default(self, column) + return self.connection.execute_text(exc).scalar() + + return super(ansisql.ANSIDefaultRunner, self).get_column_default(column) def visit_sequence(self, seq): if not seq.optional: - c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq)) - return c.fetchone()[0] + return self.connection.execute("select nextval('%s')" % self.dialect.identifier_preparer.format_sequence(seq)).scalar() else: return None diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 
b29be9eed..9270f2a5f 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -12,19 +12,19 @@ import sqlalchemy.engine.default as default import sqlalchemy.types as sqltypes import datetime,time -pysqlite2_timesupport = False # Change this if the init.d guys ever get around to supporting time cols - -try: - from pysqlite2 import dbapi2 as sqlite -except ImportError: +def dbapi(): try: - from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. - except ImportError: + from pysqlite2 import dbapi2 as sqlite + except ImportError, e: try: - sqlite = __import__('sqlite') # skip ourselves - except: - sqlite = None - + from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. + except ImportError: + try: + sqlite = __import__('sqlite') # skip ourselves + except ImportError: + raise e + return sqlite + class SLNumeric(sqltypes.Numeric): def get_col_spec(self): if self.precision is None: @@ -140,10 +140,6 @@ pragma_names = { 'BLOB' : SLBinary, } -if pysqlite2_timesupport: - colspecs.update({sqltypes.Time : SLTime}) - pragma_names.update({'TIME' : SLTime}) - def descriptor(): return {'name':'sqlite', 'description':'SQLite', @@ -152,25 +148,29 @@ def descriptor(): ]} class SQLiteExecutionContext(default.DefaultExecutionContext): - def post_exec(self, engine, proxy, compiled, parameters, **kwargs): - if getattr(compiled, "isinsert", False): - self._last_inserted_ids = [proxy().lastrowid] - + def post_exec(self): + if self.compiled.isinsert: + self._last_inserted_ids = [self.cursor.lastrowid] + super(SQLiteExecutionContext, self).post_exec() + class SQLiteDialect(ansisql.ANSIDialect): def __init__(self, **kwargs): + ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs) def vers(num): return tuple([int(x) for x in num.split('.')]) - self.supports_cast = (sqlite is not None and vers(sqlite.sqlite_version) >= vers("3.2.3")) - ansisql.ANSIDialect.__init__(self, **kwargs) + self.supports_cast = (self.dbapi is None or 
vers(self.dbapi.sqlite_version) >= vers("3.2.3")) def compiler(self, statement, bindparams, **kwargs): return SQLiteCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): - return SQLiteSchemaGenerator(*args, **kwargs) + return SQLiteSchemaGenerator(self, *args, **kwargs) def schemadropper(self, *args, **kwargs): - return SQLiteSchemaDropper(*args, **kwargs) + return SQLiteSchemaDropper(self, *args, **kwargs) + + def supports_alter(self): + return False def preparer(self): return SQLiteIdentifierPreparer(self) @@ -182,8 +182,8 @@ class SQLiteDialect(ansisql.ANSIDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) - def create_execution_context(self): - return SQLiteExecutionContext(self) + def create_execution_context(self, **kwargs): + return SQLiteExecutionContext(self, **kwargs) def last_inserted_ids(self): return self.context.last_inserted_ids @@ -191,9 +191,6 @@ class SQLiteDialect(ansisql.ANSIDialect): def oid_column_name(self, column): return "oid" - def dbapi(self): - return sqlite - def has_table(self, connection, table_name, schema=None): cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {}) row = cursor.fetchone() @@ -321,11 +318,9 @@ class SQLiteCompiler(ansisql.ANSICompiler): return ansisql.ANSICompiler.binary_operator_string(self, binary) class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): - def supports_alter(self): - return False def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec() + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -345,8 +340,7 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): # super(SQLiteSchemaGenerator, 
self).visit_primary_key_constraint(constraint) class SQLiteSchemaDropper(ansisql.ANSISchemaDropper): - def supports_alter(self): - return False + pass class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def __init__(self, dialect): |
