summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/default.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-08-06 21:11:27 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-08-06 21:11:27 +0000
commit8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch)
treeae9e27d12c9fbf8297bb90469509e1cb6a206242 /lib/sqlalchemy/engine/default.py
parent7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff)
downloadsqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz
merge 0.6 series to trunk.
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r--lib/sqlalchemy/engine/default.py230
1 files changed, 182 insertions, 48 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 728b932a2..935d1e087 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -13,36 +13,59 @@ as the base class for their own corresponding classes.
"""
import re, random
-from sqlalchemy.engine import base
+from sqlalchemy.engine import base, reflection
from sqlalchemy.sql import compiler, expression
-from sqlalchemy import exc
+from sqlalchemy import exc, types as sqltypes, util
AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
re.I | re.UNICODE)
+
class DefaultDialect(base.Dialect):
"""Default implementation of Dialect"""
- name = 'default'
- schemagenerator = compiler.SchemaGenerator
- schemadropper = compiler.SchemaDropper
- statement_compiler = compiler.DefaultCompiler
+ statement_compiler = compiler.SQLCompiler
+ ddl_compiler = compiler.DDLCompiler
+ type_compiler = compiler.GenericTypeCompiler
preparer = compiler.IdentifierPreparer
defaultrunner = base.DefaultRunner
supports_alter = True
+
+ supports_sequences = False
+ sequences_optional = False
+ preexecute_autoincrement_sequences = False
+ postfetch_lastrowid = True
+ implicit_returning = False
+
+ # Py3K
+ #supports_unicode_statements = True
+ #supports_unicode_binds = True
+ # Py2K
supports_unicode_statements = False
+ supports_unicode_binds = False
+ # end Py2K
+
+ name = 'default'
max_identifier_length = 9999
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
- preexecute_pk_sequences = False
- supports_pk_autoincrement = True
dbapi_type_map = {}
default_paramstyle = 'named'
- supports_default_values = False
+ supports_default_values = False
supports_empty_insert = True
+
+ # indicates symbol names are
+ # UPPERCASEd if they are case insensitive
+ # within the database.
+ # if this is True, the methods normalize_name()
+ # and denormalize_name() must be provided.
+ requires_name_normalize = False
+
+ reflection_options = ()
def __init__(self, convert_unicode=False, assert_unicode=False,
- encoding='utf-8', paramstyle=None, dbapi=None,
+ encoding='utf-8', paramstyle=None, dbapi=None,
+ implicit_returning=None,
label_length=None, **kwargs):
self.convert_unicode = convert_unicode
self.assert_unicode = assert_unicode
@@ -56,28 +79,58 @@ class DefaultDialect(base.Dialect):
self.paramstyle = self.dbapi.paramstyle
else:
self.paramstyle = self.default_paramstyle
+ if implicit_returning is not None:
+ self.implicit_returning = implicit_returning
self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
self.identifier_preparer = self.preparer(self)
+ self.type_compiler = self.type_compiler(self)
+
if label_length and label_length > self.max_identifier_length:
- raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length))
+ raise exc.ArgumentError("Label length of %d is greater than this dialect's"
+ " maximum identifier length of %d" %
+ (label_length, self.max_identifier_length))
self.label_length = label_length
- self.description_encoding = getattr(self, 'description_encoding', encoding)
- def type_descriptor(self, typeobj):
+ if not hasattr(self, 'description_encoding'):
+ self.description_encoding = getattr(self, 'description_encoding', encoding)
+
+ # Py3K
+ ## work around dialects that might change these values
+ #self.supports_unicode_statements = True
+ #self.supports_unicode_binds = True
+
+ def initialize(self, connection):
+ if hasattr(self, '_get_server_version_info'):
+ self.server_version_info = self._get_server_version_info(connection)
+ if hasattr(self, '_get_default_schema_name'):
+ self.default_schema_name = self._get_default_schema_name(connection)
+
+ @classmethod
+ def type_descriptor(cls, typeobj):
"""Provide a database-specific ``TypeEngine`` object, given
the generic object which comes from the types module.
- Subclasses will usually use the ``adapt_type()`` method in the
- types module to make this job easy."""
+ This method looks for a dictionary called
+ ``colspecs`` as a class or instance-level variable,
+ and passes on to ``types.adapt_type()``.
- if type(typeobj) is type:
- typeobj = typeobj()
- return typeobj
+ """
+ return sqltypes.adapt_type(typeobj, cls.colspecs)
+
+ def reflecttable(self, connection, table, include_columns):
+ insp = reflection.Inspector.from_engine(connection)
+ return insp.reflecttable(table, include_columns)
def validate_identifier(self, ident):
if len(ident) > self.max_identifier_length:
- raise exc.IdentifierError("Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length))
-
+ raise exc.IdentifierError(
+ "Identifier '%s' exceeds maximum length of %d characters" %
+ (ident, self.max_identifier_length)
+ )
+
+ def connect(self, *cargs, **cparams):
+ return self.dbapi.connect(*cargs, **cparams)
+
def do_begin(self, connection):
"""Implementations might want to put logic here for turning
autocommit on/off, etc.
@@ -103,7 +156,8 @@ class DefaultDialect(base.Dialect):
"""Create a random two-phase transaction ID.
This id will be passed to do_begin_twophase(), do_rollback_twophase(),
- do_commit_twophase(). Its format is unspecified."""
+ do_commit_twophase(). Its format is unspecified.
+ """
return "_sa_%032x" % random.randint(0, 2 ** 128)
@@ -127,13 +181,30 @@ class DefaultDialect(base.Dialect):
class DefaultExecutionContext(base.ExecutionContext):
- def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None):
+
+ def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None):
self.dialect = dialect
self._connection = self.root_connection = connection
- self.compiled = compiled
self.engine = connection.engine
- if compiled is not None:
+ if compiled_ddl is not None:
+ self.compiled = compiled = compiled_ddl
+ if not dialect.supports_unicode_statements:
+ self.statement = unicode(compiled).encode(self.dialect.encoding)
+ else:
+ self.statement = unicode(compiled)
+ self.isinsert = self.isupdate = self.isdelete = self.executemany = False
+ self.should_autocommit = True
+ self.result_map = None
+ self.cursor = self.create_cursor()
+ self.compiled_parameters = []
+ if self.dialect.positional:
+ self.parameters = [()]
+ else:
+ self.parameters = [{}]
+ elif compiled_sql is not None:
+ self.compiled = compiled = compiled_sql
+
# compiled clauseelement. process bind params, process table defaults,
# track collections used by ResultProxy to target and process results
@@ -156,6 +227,7 @@ class DefaultExecutionContext(base.ExecutionContext):
self.isinsert = compiled.isinsert
self.isupdate = compiled.isupdate
+ self.isdelete = compiled.isdelete
self.should_autocommit = compiled.statement._autocommit
if isinstance(compiled.statement, expression._TextClause):
self.should_autocommit = self.should_autocommit or self.should_autocommit_text(self.statement)
@@ -173,31 +245,43 @@ class DefaultExecutionContext(base.ExecutionContext):
self.parameters = self.__convert_compiled_params(self.compiled_parameters)
elif statement is not None:
- # plain text statement.
- self.result_map = None
+ # plain text statement
+ self.result_map = self.compiled = None
self.parameters = self.__encode_param_keys(parameters)
self.executemany = len(parameters) > 1
if isinstance(statement, unicode) and not dialect.supports_unicode_statements:
self.statement = statement.encode(self.dialect.encoding)
else:
self.statement = statement
- self.isinsert = self.isupdate = False
+ self.isinsert = self.isupdate = self.isdelete = False
self.cursor = self.create_cursor()
self.should_autocommit = self.should_autocommit_text(statement)
else:
# no statement. used for standalone ColumnDefault execution.
- self.statement = None
- self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False
+ self.statement = self.compiled = None
+ self.isinsert = self.isupdate = self.isdelete = self.executemany = self.should_autocommit = False
self.cursor = self.create_cursor()
-
+
+ @util.memoized_property
+ def _is_explicit_returning(self):
+ return self.compiled and \
+ getattr(self.compiled.statement, '_returning', False)
+
+ @util.memoized_property
+ def _is_implicit_returning(self):
+ return self.compiled and \
+ bool(self.compiled.returning) and \
+ not self.compiled.statement._returning
+
@property
def connection(self):
return self._connection._branch()
def __encode_param_keys(self, params):
- """apply string encoding to the keys of dictionary-based bind parameters.
+ """Apply string encoding to the keys of dictionary-based bind parameters.
- This is only used executing textual, non-compiled SQL expressions."""
+ This is only used executing textual, non-compiled SQL expressions.
+ """
if self.dialect.positional or self.dialect.supports_unicode_statements:
if params:
@@ -216,7 +300,7 @@ class DefaultExecutionContext(base.ExecutionContext):
return [proc(d) for d in params] or [{}]
def __convert_compiled_params(self, compiled_parameters):
- """convert the dictionary of bind parameter values into a dict or list
+ """Convert the dictionary of bind parameter values into a dict or list
to be sent to the DBAPI's execute() or executemany() method.
"""
@@ -263,26 +347,69 @@ class DefaultExecutionContext(base.ExecutionContext):
def post_exec(self):
pass
+ def get_lastrowid(self):
+ """return self.cursor.lastrowid, or equivalent, after an INSERT.
+
+ This may involve calling special cursor functions,
+ issuing a new SELECT on the cursor (or a new one),
+ or returning a stored value that was
+ calculated within post_exec().
+
+ This function will only be called for dialects
+ which support "implicit" primary key generation,
+ keep preexecute_autoincrement_sequences set to False,
+ and when no explicit id value was bound to the
+ statement.
+
+ The function is called once, directly after
+ post_exec() and before the transaction is committed
+ or ResultProxy is generated. If the post_exec()
+ method assigns a value to `self._lastrowid`, the
+ value is used in place of calling get_lastrowid().
+
+ Note that this method is *not* equivalent to the
+ ``lastrowid`` method on ``ResultProxy``, which is a
+ direct proxy to the DBAPI ``lastrowid`` accessor
+ in all cases.
+
+ """
+
+ return self.cursor.lastrowid
+
def handle_dbapi_exception(self, e):
pass
def get_result_proxy(self):
return base.ResultProxy(self)
+
+ @property
+ def rowcount(self):
+ return self.cursor.rowcount
- def get_rowcount(self):
- if hasattr(self, '_rowcount'):
- return self._rowcount
- else:
- return self.cursor.rowcount
-
def supports_sane_rowcount(self):
return self.dialect.supports_sane_rowcount
def supports_sane_multi_rowcount(self):
return self.dialect.supports_sane_multi_rowcount
-
- def last_inserted_ids(self):
- return self._last_inserted_ids
+
+ def post_insert(self):
+ if self.dialect.postfetch_lastrowid and \
+ (not len(self._inserted_primary_key) or \
+ None in self._inserted_primary_key):
+
+ table = self.compiled.statement.table
+ lastrowid = self.get_lastrowid()
+ self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v
+ for c, v in zip(table.primary_key, self._inserted_primary_key)
+ ]
+
+ def _fetch_implicit_returning(self, resultproxy):
+ table = self.compiled.statement.table
+ row = resultproxy.first()
+
+ self._inserted_primary_key = [v is not None and v or row[c]
+ for c, v in zip(table.primary_key, self._inserted_primary_key)
+ ]
def last_inserted_params(self):
return self._last_inserted_params
@@ -293,12 +420,15 @@ class DefaultExecutionContext(base.ExecutionContext):
def lastrow_has_defaults(self):
return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
- def set_input_sizes(self):
+ def set_input_sizes(self, translate=None, exclude_types=None):
"""Given a cursor and ClauseParameters, call the appropriate
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
"""
+ if not hasattr(self.compiled, 'bind_names'):
+ return
+
types = dict(
(self.compiled.bind_names[bindparam], bindparam.type)
for bindparam in self.compiled.bind_names)
@@ -308,7 +438,7 @@ class DefaultExecutionContext(base.ExecutionContext):
for key in self.compiled.positiontup:
typeengine = types[key]
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
- if dbtype is not None:
+ if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
inputsizes.append(dbtype)
try:
self.cursor.setinputsizes(*inputsizes)
@@ -320,7 +450,9 @@ class DefaultExecutionContext(base.ExecutionContext):
for key in self.compiled.bind_names.values():
typeengine = types[key]
dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
- if dbtype is not None:
+ if dbtype is not None and (not exclude_types or dbtype not in exclude_types):
+ if translate:
+ key = translate.get(key, key)
inputsizes[key.encode(self.dialect.encoding)] = dbtype
try:
self.cursor.setinputsizes(**inputsizes)
@@ -329,8 +461,9 @@ class DefaultExecutionContext(base.ExecutionContext):
raise
def __process_defaults(self):
- """generate default values for compiled insert/update statements,
- and generate last_inserted_ids() collection."""
+ """Generate default values for compiled insert/update statements,
+ and generate inserted_primary_key collection.
+ """
if self.executemany:
if len(self.compiled.prefetch):
@@ -364,7 +497,8 @@ class DefaultExecutionContext(base.ExecutionContext):
compiled_parameters[c.key] = val
if self.isinsert:
- self._last_inserted_ids = [compiled_parameters.get(c.key, None) for c in self.compiled.statement.table.primary_key]
+ self._inserted_primary_key = [compiled_parameters.get(c.key, None)
+ for c in self.compiled.statement.table.primary_key]
self._last_inserted_params = compiled_parameters
else:
self._last_updated_params = compiled_parameters