diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
| commit | 8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch) | |
| tree | ae9e27d12c9fbf8297bb90469509e1cb6a206242 /lib/sqlalchemy/engine/default.py | |
| parent | 7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff) | |
| download | sqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz | |
merge 0.6 series to trunk.
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 230 |
1 files changed, 182 insertions, 48 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 728b932a2..935d1e087 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -13,36 +13,59 @@ as the base class for their own corresponding classes. """ import re, random -from sqlalchemy.engine import base +from sqlalchemy.engine import base, reflection from sqlalchemy.sql import compiler, expression -from sqlalchemy import exc +from sqlalchemy import exc, types as sqltypes, util AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', re.I | re.UNICODE) + class DefaultDialect(base.Dialect): """Default implementation of Dialect""" - name = 'default' - schemagenerator = compiler.SchemaGenerator - schemadropper = compiler.SchemaDropper - statement_compiler = compiler.DefaultCompiler + statement_compiler = compiler.SQLCompiler + ddl_compiler = compiler.DDLCompiler + type_compiler = compiler.GenericTypeCompiler preparer = compiler.IdentifierPreparer defaultrunner = base.DefaultRunner supports_alter = True + + supports_sequences = False + sequences_optional = False + preexecute_autoincrement_sequences = False + postfetch_lastrowid = True + implicit_returning = False + + # Py3K + #supports_unicode_statements = True + #supports_unicode_binds = True + # Py2K supports_unicode_statements = False + supports_unicode_binds = False + # end Py2K + + name = 'default' max_identifier_length = 9999 supports_sane_rowcount = True supports_sane_multi_rowcount = True - preexecute_pk_sequences = False - supports_pk_autoincrement = True dbapi_type_map = {} default_paramstyle = 'named' - supports_default_values = False + supports_default_values = False supports_empty_insert = True + + # indicates symbol names are + # UPPERCASEd if they are case insensitive + # within the database. + # if this is True, the methods normalize_name() + # and denormalize_name() must be provided. + requires_name_normalize = False + + reflection_options = () def __init__(self, convert_unicode=False, assert_unicode=False, - encoding='utf-8', paramstyle=None, dbapi=None, + encoding='utf-8', paramstyle=None, dbapi=None, + implicit_returning=None, label_length=None, **kwargs): self.convert_unicode = convert_unicode self.assert_unicode = assert_unicode @@ -56,28 +79,58 @@ class DefaultDialect(base.Dialect): self.paramstyle = self.dbapi.paramstyle else: self.paramstyle = self.default_paramstyle + if implicit_returning is not None: + self.implicit_returning = implicit_returning self.positional = self.paramstyle in ('qmark', 'format', 'numeric') self.identifier_preparer = self.preparer(self) + self.type_compiler = self.type_compiler(self) + if label_length and label_length > self.max_identifier_length: - raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length)) + raise exc.ArgumentError("Label length of %d is greater than this dialect's" + " maximum identifier length of %d" % + (label_length, self.max_identifier_length)) self.label_length = label_length - self.description_encoding = getattr(self, 'description_encoding', encoding) - def type_descriptor(self, typeobj): + if not hasattr(self, 'description_encoding'): + self.description_encoding = getattr(self, 'description_encoding', encoding) + + # Py3K + ## work around dialects that might change these values + #self.supports_unicode_statements = True + #self.supports_unicode_binds = True + + def initialize(self, connection): + if hasattr(self, '_get_server_version_info'): + self.server_version_info = self._get_server_version_info(connection) + if hasattr(self, '_get_default_schema_name'): + self.default_schema_name = self._get_default_schema_name(connection) + + @classmethod + def type_descriptor(cls, typeobj): """Provide a database-specific ``TypeEngine`` object, given the generic object which comes from the types module. - Subclasses will usually use the ``adapt_type()`` method in the - types module to make this job easy.""" + This method looks for a dictionary called + ``colspecs`` as a class or instance-level variable, + and passes on to ``types.adapt_type()``. - if type(typeobj) is type: - typeobj = typeobj() - return typeobj + """ + return sqltypes.adapt_type(typeobj, cls.colspecs) + + def reflecttable(self, connection, table, include_columns): + insp = reflection.Inspector.from_engine(connection) + return insp.reflecttable(table, include_columns) def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: - raise exc.IdentifierError("Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length)) - + raise exc.IdentifierError( + "Identifier '%s' exceeds maximum length of %d characters" % + (ident, self.max_identifier_length) + ) + + def connect(self, *cargs, **cparams): + return self.dbapi.connect(*cargs, **cparams) + def do_begin(self, connection): """Implementations might want to put logic here for turning autocommit on/off, etc. @@ -103,7 +156,8 @@ class DefaultDialect(base.Dialect): """Create a random two-phase transaction ID. This id will be passed to do_begin_twophase(), do_rollback_twophase(), - do_commit_twophase(). Its format is unspecified.""" + do_commit_twophase(). Its format is unspecified. + """ return "_sa_%032x" % random.randint(0, 2 ** 128) @@ -127,13 +181,30 @@ class DefaultDialect(base.Dialect): class DefaultExecutionContext(base.ExecutionContext): - def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): + + def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None): self.dialect = dialect self._connection = self.root_connection = connection - self.compiled = compiled self.engine = connection.engine - if compiled is not None: + if compiled_ddl is not None: + self.compiled = compiled = compiled_ddl + if not dialect.supports_unicode_statements: + self.statement = unicode(compiled).encode(self.dialect.encoding) + else: + self.statement = unicode(compiled) + self.isinsert = self.isupdate = self.isdelete = self.executemany = False + self.should_autocommit = True + self.result_map = None + self.cursor = self.create_cursor() + self.compiled_parameters = [] + if self.dialect.positional: + self.parameters = [()] + else: + self.parameters = [{}] + elif compiled_sql is not None: + self.compiled = compiled = compiled_sql + # compiled clauseelement. process bind params, process table defaults, # track collections used by ResultProxy to target and process results @@ -156,6 +227,7 @@ class DefaultExecutionContext(base.ExecutionContext): self.isinsert = compiled.isinsert self.isupdate = compiled.isupdate + self.isdelete = compiled.isdelete self.should_autocommit = compiled.statement._autocommit if isinstance(compiled.statement, expression._TextClause): self.should_autocommit = self.should_autocommit or self.should_autocommit_text(self.statement) @@ -173,31 +245,43 @@ class DefaultExecutionContext(base.ExecutionContext): self.parameters = self.__convert_compiled_params(self.compiled_parameters) elif statement is not None: - # plain text statement. - self.result_map = None + # plain text statement + self.result_map = self.compiled = None self.parameters = self.__encode_param_keys(parameters) self.executemany = len(parameters) > 1 if isinstance(statement, unicode) and not dialect.supports_unicode_statements: self.statement = statement.encode(self.dialect.encoding) else: self.statement = statement - self.isinsert = self.isupdate = False + self.isinsert = self.isupdate = self.isdelete = False self.cursor = self.create_cursor() self.should_autocommit = self.should_autocommit_text(statement) else: # no statement. used for standalone ColumnDefault execution. - self.statement = None - self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False + self.statement = self.compiled = None + self.isinsert = self.isupdate = self.isdelete = self.executemany = self.should_autocommit = False self.cursor = self.create_cursor() - + + @util.memoized_property + def _is_explicit_returning(self): + return self.compiled and \ + getattr(self.compiled.statement, '_returning', False) + + @util.memoized_property + def _is_implicit_returning(self): + return self.compiled and \ + bool(self.compiled.returning) and \ + not self.compiled.statement._returning + @property def connection(self): return self._connection._branch() def __encode_param_keys(self, params): - """apply string encoding to the keys of dictionary-based bind parameters. + """Apply string encoding to the keys of dictionary-based bind parameters. - This is only used executing textual, non-compiled SQL expressions.""" + This is only used executing textual, non-compiled SQL expressions. + """ if self.dialect.positional or self.dialect.supports_unicode_statements: if params: @@ -216,7 +300,7 @@ class DefaultExecutionContext(base.ExecutionContext): return [proc(d) for d in params] or [{}] def __convert_compiled_params(self, compiled_parameters): - """convert the dictionary of bind parameter values into a dict or list + """Convert the dictionary of bind parameter values into a dict or list to be sent to the DBAPI's execute() or executemany() method. """ @@ -263,26 +347,69 @@ class DefaultExecutionContext(base.ExecutionContext): def post_exec(self): pass + def get_lastrowid(self): + """return self.cursor.lastrowid, or equivalent, after an INSERT. + + This may involve calling special cursor functions, + issuing a new SELECT on the cursor (or a new one), + or returning a stored value that was + calculated within post_exec(). + + This function will only be called for dialects + which support "implicit" primary key generation, + keep preexecute_autoincrement_sequences set to False, + and when no explicit id value was bound to the + statement. + + The function is called once, directly after + post_exec() and before the transaction is committed + or ResultProxy is generated. If the post_exec() + method assigns a value to `self._lastrowid`, the + value is used in place of calling get_lastrowid(). + + Note that this method is *not* equivalent to the + ``lastrowid`` method on ``ResultProxy``, which is a + direct proxy to the DBAPI ``lastrowid`` accessor + in all cases. + + """ + + return self.cursor.lastrowid + def handle_dbapi_exception(self, e): pass def get_result_proxy(self): return base.ResultProxy(self) + + @property + def rowcount(self): + return self.cursor.rowcount - def get_rowcount(self): - if hasattr(self, '_rowcount'): - return self._rowcount - else: - return self.cursor.rowcount - def supports_sane_rowcount(self): return self.dialect.supports_sane_rowcount def supports_sane_multi_rowcount(self): return self.dialect.supports_sane_multi_rowcount - - def last_inserted_ids(self): - return self._last_inserted_ids + + def post_insert(self): + if self.dialect.postfetch_lastrowid and \ + (not len(self._inserted_primary_key) or \ + None in self._inserted_primary_key): + + table = self.compiled.statement.table + lastrowid = self.get_lastrowid() + self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v + for c, v in zip(table.primary_key, self._inserted_primary_key) + ] + + def _fetch_implicit_returning(self, resultproxy): + table = self.compiled.statement.table + row = resultproxy.first() + + self._inserted_primary_key = [v is not None and v or row[c] + for c, v in zip(table.primary_key, self._inserted_primary_key) + ] def last_inserted_params(self): return self._last_inserted_params @@ -293,12 +420,15 @@ class DefaultExecutionContext(base.ExecutionContext): def lastrow_has_defaults(self): return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols) - def set_input_sizes(self): + def set_input_sizes(self, translate=None, exclude_types=None): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. """ + if not hasattr(self.compiled, 'bind_names'): + return + types = dict( (self.compiled.bind_names[bindparam], bindparam.type) for bindparam in self.compiled.bind_names) @@ -308,7 +438,7 @@ class DefaultExecutionContext(base.ExecutionContext): for key in self.compiled.positiontup: typeengine = types[key] dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if dbtype is not None: + if dbtype is not None and (not exclude_types or dbtype not in exclude_types): inputsizes.append(dbtype) try: self.cursor.setinputsizes(*inputsizes) @@ -320,7 +450,9 @@ class DefaultExecutionContext(base.ExecutionContext): for key in self.compiled.bind_names.values(): typeengine = types[key] dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if dbtype is not None: + if dbtype is not None and (not exclude_types or dbtype not in exclude_types): + if translate: + key = translate.get(key, key) inputsizes[key.encode(self.dialect.encoding)] = dbtype try: self.cursor.setinputsizes(**inputsizes) @@ -329,8 +461,9 @@ class DefaultExecutionContext(base.ExecutionContext): raise def __process_defaults(self): - """generate default values for compiled insert/update statements, - and generate last_inserted_ids() collection.""" + """Generate default values for compiled insert/update statements, + and generate inserted_primary_key collection. + """ if self.executemany: if len(self.compiled.prefetch): @@ -364,7 +497,8 @@ class DefaultExecutionContext(base.ExecutionContext): compiled_parameters[c.key] = val if self.isinsert: - self._last_inserted_ids = [compiled_parameters.get(c.key, None) for c in self.compiled.statement.table.primary_key] + self._inserted_primary_key = [compiled_parameters.get(c.key, None) + for c in self.compiled.statement.table.primary_key] self._last_inserted_params = compiled_parameters else: self._last_updated_params = compiled_parameters |
