diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
| commit | 8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch) | |
| tree | ae9e27d12c9fbf8297bb90469509e1cb6a206242 /lib/sqlalchemy/engine | |
| parent | 7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff) | |
| download | sqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz | |
merge 0.6 series to trunk.
Diffstat (limited to 'lib/sqlalchemy/engine')
| -rw-r--r-- | lib/sqlalchemy/engine/__init__.py | 12 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 750 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/ddl.py | 128 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 230 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/reflection.py | 361 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/strategies.py | 90 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/threadlocal.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/url.py | 46 |
8 files changed, 1278 insertions, 343 deletions
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index bb2b1b5be..694a2f71f 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -50,7 +50,9 @@ url.py within a URL. """ -import sqlalchemy.databases +# not sure what this was used for +#import sqlalchemy.databases + from sqlalchemy.engine.base import ( BufferedColumnResultProxy, BufferedColumnRow, @@ -66,9 +68,9 @@ from sqlalchemy.engine.base import ( ResultProxy, RootTransaction, RowProxy, - SchemaIterator, Transaction, - TwoPhaseTransaction + TwoPhaseTransaction, + TypeCompiler ) from sqlalchemy.engine import strategies from sqlalchemy import util @@ -89,9 +91,9 @@ __all__ = ( 'ResultProxy', 'RootTransaction', 'RowProxy', - 'SchemaIterator', 'Transaction', 'TwoPhaseTransaction', + 'TypeCompiler', 'create_engine', 'engine_from_config', ) @@ -108,7 +110,7 @@ def create_engine(*args, **kwargs): The URL is a string in the form ``dialect://user:password@host/dbname[?key=value..]``, where - ``dialect`` is a name such as ``mysql``, ``oracle``, ``postgres``, + ``dialect`` is a name such as ``mysql``, ``oracle``, ``postgresql``, etc. Alternatively, the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`. diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 39085c359..0a0b0ff0c 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -10,14 +10,16 @@ Defines the basic components used to interface DB-API modules with higher-level statement-construction, connection-management, execution and result contexts. 
- """ -__all__ = ['BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultProxy', 'Compiled', 'Connectable', - 'Connection', 'DefaultRunner', 'Dialect', 'Engine', 'ExecutionContext', 'NestedTransaction', 'ResultProxy', - 'RootTransaction', 'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction', 'connection_memoize'] +__all__ = [ + 'BufferedColumnResultProxy', 'BufferedColumnRow', 'BufferedRowResultProxy', + 'Compiled', 'Connectable', 'Connection', 'DefaultRunner', 'Dialect', 'Engine', + 'ExecutionContext', 'NestedTransaction', 'ResultProxy', 'RootTransaction', + 'RowProxy', 'SchemaIterator', 'StringIO', 'Transaction', 'TwoPhaseTransaction', + 'connection_memoize'] -import inspect, StringIO +import inspect, StringIO, sys from sqlalchemy import exc, schema, util, types, log from sqlalchemy.sql import expression @@ -32,10 +34,14 @@ class Dialect(object): ExecutionContext, Compiled, DefaultGenerator, and TypeEngine. All Dialects implement the following attributes: - + name - identifying name for the dialect (i.e. 'sqlite') - + identifying name for the dialect from a DBAPI-neutral point of view + (i.e. 'sqlite') + + driver + identifying name for the dialect's DBAPI + positional True if the paramstyle for this Dialect is positional. @@ -51,20 +57,25 @@ class Dialect(object): type of encoding to use for unicode, usually defaults to 'utf-8'. - schemagenerator - a :class:`~sqlalchemy.schema.SchemaVisitor` class which generates - schemas. - - schemadropper - a :class:`~sqlalchemy.schema.SchemaVisitor` class which drops schemas. - defaultrunner a :class:`~sqlalchemy.schema.SchemaVisitor` class which executes defaults. 
statement_compiler - a :class:`~sqlalchemy.engine.base.Compiled` class used to compile SQL - statements + a :class:`~Compiled` class used to compile SQL statements + + ddl_compiler + a :class:`~Compiled` class used to compile DDL statements + + server_version_info + a tuple containing a version number for the DB backend in use. + This value is only available for supporting dialects, and only for + a dialect that's been associated with a connection pool via + create_engine() or otherwise had its ``initialize()`` method called + with a conneciton. + + execution_ctx_cls + a :class:`ExecutionContext` class used to handle statement execution preparer a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to @@ -77,27 +88,38 @@ class Dialect(object): The maximum length of identifier names. supports_unicode_statements - Indicate whether the DB-API can receive SQL statements as Python unicode strings + Indicate whether the DB-API can receive SQL statements as Python + unicode strings + + supports_unicode_binds + Indicate whether the DB-API can receive string bind parameters + as Python unicode strings supports_sane_rowcount - Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements. + Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements. supports_sane_multi_rowcount - Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements - when executed via executemany. - - preexecute_pk_sequences - Indicate if the dialect should pre-execute sequences on primary key - columns during an INSERT, if it's desired that the new row's primary key - be available after execution. - - supports_pk_autoincrement - Indicates if the dialect should allow the database to passively assign - a primary key column value. - + Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements when executed via + executemany. 
+ + preexecute_autoincrement_sequences + True if 'implicit' primary key functions must be executed separately + in order to get their value. This is currently oriented towards + Postgresql. + + implicit_returning + use RETURNING or equivalent during INSERT execution in order to load + newly generated primary keys and other column defaults in one execution, + which are then available via inserted_primary_key. + If an insert statement has returning() specified explicitly, + the "implicit" functionality is not used and inserted_primary_key + will not be available. + dbapi_type_map A mapping of DB-API type objects present in this Dialect's - DB-API implmentation mapped to TypeEngine implementations used + DB-API implementation mapped to TypeEngine implementations used by the dialect. This is used to apply types to result sets based on the DB-API @@ -105,13 +127,15 @@ class Dialect(object): result sets against textual statements where no explicit typemap was present. - supports_default_values - Indicates if the construct ``INSERT INTO tablename DEFAULT VALUES`` is supported + colspecs + A dictionary of TypeEngine classes from sqlalchemy.types mapped + to subclasses that are specific to the dialect class. This + dictionary is class-level only and is not accessed from the + dialect instance itself. - description_encoding - type of encoding to use for unicode when working with metadata - descriptions. If set to ``None`` no encoding will be done. - This usually defaults to 'utf-8'. + supports_default_values + Indicates if the construct ``INSERT INTO tablename DEFAULT + VALUES`` is supported """ def create_connect_args(self, url): @@ -124,25 +148,28 @@ class Dialect(object): raise NotImplementedError() + @classmethod + def type_descriptor(cls, typeobj): + """Transform a generic type to a dialect-specific type. - def type_descriptor(self, typeobj): - """Transform a generic type to a database-specific type. 
- - Transforms the given :class:`~sqlalchemy.types.TypeEngine` instance - from generic to database-specific. - - Subclasses will usually use the + Dialect classes will usually use the :func:`~sqlalchemy.types.adapt_type` method in the types module to make this job easy. + + The returned result is cached *per dialect class* so can + contain no dialect-instance state. """ raise NotImplementedError() + def initialize(self, connection): + """Called during strategized creation of the dialect with a connection. - def server_version_info(self, connection): - """Return a tuple of the database's version number.""" + Allows dialects to configure options based on server version info or + other properties. + """ - raise NotImplementedError() + pass def reflecttable(self, connection, table, include_columns=None): """Load table description from the database. @@ -156,6 +183,133 @@ class Dialect(object): raise NotImplementedError() + def get_columns(self, connection, table_name, schema=None, **kw): + """Return information about columns in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name`, and an optional string `schema`, return column + information as a list of dictionaries with these keys: + + name + the column's name + + type + [sqlalchemy.types#TypeEngine] + + nullable + boolean + + default + the column's default value + + autoincrement + boolean + + sequence + a dictionary of the form + {'name' : str, 'start' :int, 'increment': int} + + Additional column attributes may be present. + """ + + raise NotImplementedError() + + def get_primary_keys(self, connection, table_name, schema=None, **kw): + """Return information about primary keys in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name`, and an optional string `schema`, return primary + key information as a list of column names. 
+ """ + + raise NotImplementedError() + + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + """Return information about foreign_keys in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name`, and an optional string `schema`, return foreign + key information as a list of dicts with these keys: + + name + the constraint's name + + constrained_columns + a list of column names that make up the foreign key + + referred_schema + the name of the referred schema + + referred_table + the name of the referred table + + referred_columns + a list of column names in the referred table that correspond to + constrained_columns + """ + + raise NotImplementedError() + + def get_table_names(self, connection, schema=None, **kw): + """Return a list of table names for `schema`.""" + + raise NotImplementedError + + def get_view_names(self, connection, schema=None, **kw): + """Return a list of all view names available in the database. + + schema: + Optional, retrieve names from a non-default schema. + """ + + raise NotImplementedError() + + def get_view_definition(self, connection, view_name, schema=None, **kw): + """Return view definition. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `view_name`, and an optional string `schema`, return the view + definition. + """ + + raise NotImplementedError() + + def get_indexes(self, connection, table_name, schema=None, **kw): + """Return information about indexes in `table_name`. + + Given a :class:`~sqlalchemy.engine.Connection`, a string + `table_name` and an optional string `schema`, return index + information as a list of dictionaries with these keys: + + name + the index's name + + column_names + list of column names in order + + unique + boolean + """ + + raise NotImplementedError() + + def normalize_name(self, name): + """convert the given name to lowercase if it is detected as case insensitive. 
+ + this method is only used if the dialect defines requires_name_normalize=True. + + """ + raise NotImplementedError() + + def denormalize_name(self, name): + """convert the given name to a case insensitive identifier for the backend + if it is an all-lowercase name. + + this method is only used if the dialect defines requires_name_normalize=True. + + """ + raise NotImplementedError() + def has_table(self, connection, table_name, schema=None): """Check the existence of a particular table in the database. @@ -178,7 +332,11 @@ class Dialect(object): raise NotImplementedError() def get_default_schema_name(self, connection): - """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`.""" + """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`. + + DEPRECATED. moving this towards dialect.default_schema_name (not complete). + + """ raise NotImplementedError() @@ -262,11 +420,14 @@ class Dialect(object): raise NotImplementedError() + def visit_pool(self, pool): + """Executed after a pool is created.""" + class ExecutionContext(object): """A messenger object for a Dialect that corresponds to a single execution. - ExecutionContext should have these datamembers: + ExecutionContext should have these data members: connection Connection object which can be freely used by default value @@ -308,20 +469,19 @@ class ExecutionContext(object): True if the statement is an UPDATE. should_autocommit - True if the statement is a "committable" statement + True if the statement is a "committable" statement. postfetch_cols - a list of Column objects for which a server-side default - or inline SQL expression value was fired off. applies to inserts and updates. - - + a list of Column objects for which a server-side default or + inline SQL expression value was fired off. Applies to inserts + and updates. 
""" def create_cursor(self): """Return a new cursor generated from this ExecutionContext's connection. Some dialects may wish to change the behavior of - connection.cursor(), such as postgres which may return a PG + connection.cursor(), such as postgresql which may return a PG "server side" cursor. """ @@ -357,22 +517,11 @@ class ExecutionContext(object): def handle_dbapi_exception(self, e): """Receive a DBAPI exception which occured upon execute, result fetch, etc.""" - - raise NotImplementedError() - - def should_autocommit_text(self, statement): - """Parse the given textual statement and return True if it refers to a "committable" statement""" raise NotImplementedError() - def last_inserted_ids(self): - """Return the list of the primary key values for the last insert statement executed. - - This does not apply to straight textual clauses; only to - ``sql.Insert`` objects compiled against a ``schema.Table`` - object. The order of items in the list is the same as that of - the Table's 'primary_key' attribute. - """ + def should_autocommit_text(self, statement): + """Parse the given textual statement and return True if it refers to a "committable" statement""" raise NotImplementedError() @@ -401,7 +550,7 @@ class ExecutionContext(object): class Compiled(object): - """Represent a compiled SQL expression. + """Represent a compiled SQL or DDL expression. The ``__str__`` method of the ``Compiled`` object should produce the actual text of the statement. ``Compiled`` objects are @@ -413,53 +562,49 @@ class Compiled(object): defaults. """ - def __init__(self, dialect, statement, column_keys=None, bind=None): + def __init__(self, dialect, statement, bind=None): """Construct a new ``Compiled`` object. - dialect - ``Dialect`` to compile against. - - statement - ``ClauseElement`` to be compiled. + :param dialect: ``Dialect`` to compile against. - column_keys - a list of column names to be compiled into an INSERT or UPDATE - statement. 
+ :param statement: ``ClauseElement`` to be compiled. - bind - Optional Engine or Connection to compile this statement against. - + :param bind: Optional Engine or Connection to compile this statement against. """ + self.dialect = dialect self.statement = statement - self.column_keys = column_keys self.bind = bind self.can_execute = statement.supports_execution def compile(self): """Produce the internal string representation of this element.""" - raise NotImplementedError() + self.string = self.process(self.statement) - def __str__(self): - """Return the string text of the generated SQL statement.""" + def process(self, obj, **kwargs): + return obj._compiler_dispatch(self, **kwargs) - raise NotImplementedError() + def __str__(self): + """Return the string text of the generated SQL or DDL.""" - @util.deprecated('Deprecated. Use construct_params(). ' - '(supports Unicode key names.)') - def get_params(self, **params): - return self.construct_params(params) + return self.string or '' - def construct_params(self, params): + def construct_params(self, params=None): """Return the bind params for this compiled object. - `params` is a dict of string/object pairs whos - values will override bind values compiled in - to the statement. + :param params: a dict of string/object pairs whos values will + override bind values compiled in to the + statement. 
""" + raise NotImplementedError() + @property + def params(self): + """Return the bind params for this compiled object.""" + return self.construct_params() + def execute(self, *multiparams, **params): """Execute this compiled object.""" @@ -474,12 +619,24 @@ class Compiled(object): return self.execute(*multiparams, **params).scalar() +class TypeCompiler(object): + """Produces DDL specification for TypeEngine objects.""" + + def __init__(self, dialect): + self.dialect = dialect + + def process(self, type_): + return type_._compiler_dispatch(self) + + class Connectable(object): """Interface for an object which supports execution of SQL constructs. - + The two implementations of ``Connectable`` are :class:`Connection` and :class:`Engine`. - + + Connectable must also implement the 'dialect' member which references a + :class:`Dialect` instance. """ def contextual_connect(self): @@ -503,6 +660,7 @@ class Connectable(object): def _execute_clauseelement(self, elem, multiparams=None, params=None): raise NotImplementedError() + class Connection(Connectable): """Provides high-level functionality for a wrapped DB-API connection. @@ -514,7 +672,6 @@ class Connection(Connectable): .. index:: single: thread safety; Connection - """ def __init__(self, engine, connection=None, close_with_result=False, @@ -524,7 +681,6 @@ class Connection(Connectable): Connection objects are typically constructed by an :class:`~sqlalchemy.engine.Engine`, see the ``connect()`` and ``contextual_connect()`` methods of Engine. - """ self.engine = engine @@ -534,7 +690,7 @@ class Connection(Connectable): self.__savepoint_seq = 0 self.__branch = _branch self.__invalid = False - + def _branch(self): """Return a new Connection which references this Connection's engine and connection; but does not have close_with_result enabled, @@ -542,8 +698,8 @@ class Connection(Connectable): This is used to execute "sub" statements within a single execution, usually an INSERT statement. 
- """ + return self.engine.Connection(self.engine, self.__connection, _branch=True) @property @@ -554,13 +710,13 @@ class Connection(Connectable): @property def closed(self): - """return True if this connection is closed.""" + """Return True if this connection is closed.""" return not self.__invalid and '_Connection__connection' not in self.__dict__ @property def invalidated(self): - """return True if this connection was invalidated.""" + """Return True if this connection was invalidated.""" return self.__invalid @@ -583,13 +739,14 @@ class Connection(Connectable): def should_close_with_result(self): """Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode. - """ + return self.__close_with_result @property def info(self): """A collection of per-DB-API connection instance properties.""" + return self.connection.info def connect(self): @@ -598,8 +755,8 @@ class Connection(Connectable): This ``Connectable`` interface method returns self, allowing Connections to be used interchangably with Engines in most situations that require a bind. - """ + return self def contextual_connect(self, **kwargs): @@ -608,8 +765,8 @@ class Connection(Connectable): This ``Connectable`` interface method returns self, allowing Connections to be used interchangably with Engines in most situations that require a bind. - """ + return self def invalidate(self, exception=None): @@ -627,8 +784,8 @@ class Connection(Connectable): rolled back before a reconnect on this Connection can proceed. This is to prevent applications from accidentally continuing their transactional operations in a non-transactional state. - """ + if self.closed: raise exc.InvalidRequestError("This Connection is closed") @@ -651,8 +808,8 @@ class Connection(Connectable): :class:`~sqlalchemy.interfaces.PoolListener` for a mechanism to modify connection state when connections leave and return to their connection pool. 
- """ + self.__connection.detach() def begin(self): @@ -663,8 +820,8 @@ class Connection(Connectable): outermost transaction may ``commit``. Calls to ``commit`` on inner transactions are ignored. Any transaction in the hierarchy may ``rollback``, however. - """ + if self.__transaction is None: self.__transaction = RootTransaction(self) else: @@ -690,9 +847,8 @@ class Connection(Connectable): def begin_twophase(self, xid=None): """Begin a two-phase or XA transaction and return a Transaction handle. - xid - the two phase transaction id. If not supplied, a random id - will be generated. + :param xid: the two phase transaction id. If not supplied, a random id + will be generated. """ if self.__transaction is not None: @@ -813,9 +969,6 @@ class Connection(Connectable): return self.execute(object, *multiparams, **params).scalar() - def statement_compiler(self, statement, **kwargs): - return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs) - def execute(self, object, *multiparams, **params): """Executes and returns a ResultProxy.""" @@ -826,11 +979,12 @@ class Connection(Connectable): raise exc.InvalidRequestError("Unexecutable object type: " + str(type(object))) def __distill_params(self, multiparams, params): - """given arguments from the calling form *multiparams, **params, return a list + """Given arguments from the calling form *multiparams, **params, return a list of bind parameter structures, usually a list of dictionaries. - in the case of 'raw' execution which accepts positional parameters, - it may be a list of tuples or lists.""" + In the case of 'raw' execution which accepts positional parameters, + it may be a list of tuples or lists. 
+ """ if not multiparams: if params: @@ -858,7 +1012,19 @@ class Connection(Connectable): return self._execute_clauseelement(func.select(), multiparams, params) def _execute_default(self, default, multiparams, params): - return self.engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default) + ret = self.engine.dialect.\ + defaultrunner(self.__create_execution_context()).\ + traverse_single(default) + if self.__close_with_result: + self.close() + return ret + + def _execute_ddl(self, ddl, params, multiparams): + context = self.__create_execution_context( + compiled_ddl=ddl.compile(dialect=self.dialect), + parameters=None + ) + return self.__execute_context(context) def _execute_clauseelement(self, elem, multiparams, params): params = self.__distill_params(multiparams, params) @@ -868,7 +1034,7 @@ class Connection(Connectable): keys = [] context = self.__create_execution_context( - compiled=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), + compiled_sql=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1), parameters=params ) return self.__execute_context(context) @@ -877,7 +1043,7 @@ class Connection(Connectable): """Execute a sql.Compiled object.""" context = self.__create_execution_context( - compiled=compiled, + compiled_sql=compiled, parameters=self.__distill_params(multiparams, params) ) return self.__execute_context(context) @@ -886,38 +1052,42 @@ class Connection(Connectable): parameters = self.__distill_params(multiparams, params) context = self.__create_execution_context(statement=statement, parameters=parameters) return self.__execute_context(context) - + def __execute_context(self, context): if context.compiled: context.pre_exec() + if context.executemany: self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context) else: self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context) + if context.compiled: 
context.post_exec() + + if context.isinsert and not context.executemany: + context.post_insert() + if context.should_autocommit and not self.in_transaction(): self._commit_impl() - return context.get_result_proxy() + + return context.get_result_proxy()._autoclose() - def _execute_ddl(self, ddl, params, multiparams): - if params: - schema_item, params = params[0], params[1:] - else: - schema_item = None - return ddl(None, schema_item, self, *params, **multiparams) - def _handle_dbapi_exception(self, e, statement, parameters, cursor, context): if getattr(self, '_reentrant_error', False): - raise exc.DBAPIError.instance(None, None, e) + # Py3K + #raise exc.DBAPIError.instance(statement, parameters, e) from e + # Py2K + raise exc.DBAPIError.instance(statement, parameters, e), None, sys.exc_info()[2] + # end Py2K self._reentrant_error = True try: if not isinstance(e, self.dialect.dbapi.Error): return - + if context: context.handle_dbapi_exception(e) - + is_disconnect = self.dialect.is_disconnect(e) if is_disconnect: self.invalidate(e) @@ -928,7 +1098,12 @@ class Connection(Connectable): self._autorollback() if self.__close_with_result: self.close() - raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) + # Py3K + #raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) from e + # Py2K + raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect), None, sys.exc_info()[2] + # end Py2K + finally: del self._reentrant_error @@ -966,7 +1141,7 @@ class Connection(Connectable): expression.ClauseElement: _execute_clauseelement, Compiled: _execute_compiled, schema.SchemaItem: _execute_default, - schema.DDL: _execute_ddl, + schema.DDLElement: _execute_ddl, basestring: _execute_text } @@ -991,6 +1166,7 @@ class Connection(Connectable): def run_callable(self, callable_): return callable_(self) + class Transaction(object): """Represent a Transaction in progress. 
@@ -998,14 +1174,13 @@ class Transaction(object): .. index:: single: thread safety; Transaction - """ def __init__(self, connection, parent): self.connection = connection self._parent = parent or self self.is_active = True - + def close(self): """Close this transaction. @@ -1016,6 +1191,7 @@ class Transaction(object): This is used to cancel a Transaction without affecting the scope of an enclosing transaction. """ + if not self._parent.is_active: return if self._parent is self: @@ -1048,6 +1224,7 @@ class Transaction(object): else: self.rollback() + class RootTransaction(Transaction): def __init__(self, connection): super(RootTransaction, self).__init__(connection, None) @@ -1059,6 +1236,7 @@ class RootTransaction(Transaction): def _do_commit(self): self.connection._commit_impl() + class NestedTransaction(Transaction): def __init__(self, connection, parent): super(NestedTransaction, self).__init__(connection, parent) @@ -1070,6 +1248,7 @@ class NestedTransaction(Transaction): def _do_commit(self): self.connection._release_savepoint_impl(self._savepoint, self._parent) + class TwoPhaseTransaction(Transaction): def __init__(self, connection, xid): super(TwoPhaseTransaction, self).__init__(connection, None) @@ -1089,9 +1268,10 @@ class TwoPhaseTransaction(Transaction): def commit(self): self.connection._commit_twophase_impl(self.xid, self._is_prepared) + class Engine(Connectable): """ - Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.base.Dialect` + Connects a :class:`~sqlalchemy.pool.Pool` and :class:`~sqlalchemy.engine.base.Dialect` together to provide a source of database connectivity and behavior. """ @@ -1111,9 +1291,15 @@ class Engine(Connectable): @property def name(self): "String name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``." - + return self.dialect.name + @property + def driver(self): + "Driver name of the :class:`~sqlalchemy.engine.Dialect` in use by this ``Engine``." 
+ + return self.dialect.driver + echo = log.echo_property() def __repr__(self): @@ -1126,12 +1312,16 @@ class Engine(Connectable): def create(self, entity, connection=None, **kwargs): """Create a table or index within this engine's database connection given a schema.Table object.""" - self._run_visitor(self.dialect.schemagenerator, entity, connection=connection, **kwargs) + from sqlalchemy.engine import ddl + + self._run_visitor(ddl.SchemaGenerator, entity, connection=connection, **kwargs) def drop(self, entity, connection=None, **kwargs): """Drop a table or index within this engine's database connection given a schema.Table object.""" - self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs) + from sqlalchemy.engine import ddl + + self._run_visitor(ddl.SchemaDropper, entity, connection=connection, **kwargs) def _execute_default(self, default): connection = self.contextual_connect() @@ -1212,9 +1402,6 @@ class Engine(Connectable): connection = self.contextual_connect(close_with_result=True) return connection._execute_compiled(compiled, multiparams, params) - def statement_compiler(self, statement, **kwargs): - return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs) - def connect(self, **kwargs): """Return a newly allocated Connection object.""" @@ -1231,12 +1418,10 @@ class Engine(Connectable): def table_names(self, schema=None, connection=None): """Return a list of all table names available in the database. - schema: - Optional, retrieve names from a non-default schema. + :param schema: Optional, retrieve names from a non-default schema. - connection: - Optional, use a specified connection. Default is the - ``contextual_connect`` for this ``Engine``. + :param connection: Optional, use a specified connection. Default is the + ``contextual_connect`` for this ``Engine``. 
""" if connection is None: @@ -1275,22 +1460,24 @@ class Engine(Connectable): return self.pool.unique_connection() + def _proxy_connection_cls(cls, proxy): class ProxyConnection(cls): def execute(self, object, *multiparams, **params): return proxy.execute(self, super(ProxyConnection, self).execute, object, *multiparams, **params) - + def _execute_clauseelement(self, elem, multiparams=None, params=None): return proxy.execute(self, super(ProxyConnection, self).execute, elem, *(multiparams or []), **(params or {})) - + def _cursor_execute(self, cursor, statement, parameters, context=None): return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, cursor, statement, parameters, context, False) - + def _cursor_executemany(self, cursor, statement, parameters, context=None): return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, cursor, statement, parameters, context, True) return ProxyConnection + class RowProxy(object): """Proxy a single cursor row for a parent ResultProxy. 
@@ -1302,7 +1489,7 @@ class RowProxy(object): """ __slots__ = ['__parent', '__row'] - + def __init__(self, parent, row): """RowProxy objects are constructed by ResultProxy objects.""" @@ -1327,7 +1514,7 @@ class RowProxy(object): yield self.__parent._get_col(self.__row, i) __hash__ = None - + def __eq__(self, other): return ((other is self) or (other == tuple(self.__parent._get_col(self.__row, key) @@ -1362,18 +1549,19 @@ class RowProxy(object): """Return the list of keys as strings represented by this RowProxy.""" return self.__parent.keys - + def iterkeys(self): return iter(self.__parent.keys) - + def values(self): """Return the values represented by this RowProxy as a list.""" return list(self) - + def itervalues(self): return iter(self) + class BufferedColumnRow(RowProxy): def __init__(self, parent, row): row = [ResultProxy._get_col(parent, row, i) for i in xrange(len(row))] @@ -1403,9 +1591,8 @@ class ResultProxy(object): """ _process_row = RowProxy - + def __init__(self, context): - """ResultProxy objects are constructed via the execute() method on SQLEngine.""" self.context = context self.dialect = context.dialect self.closed = False @@ -1413,40 +1600,81 @@ class ResultProxy(object): self.connection = context.root_connection self._echo = context.engine._should_log_info self._init_metadata() - - @property + + @util.memoized_property def rowcount(self): - if self._rowcount is None: - return self.context.get_rowcount() - else: - return self._rowcount + """Return the 'rowcount' for this result. + + The 'rowcount' reports the number of rows affected + by an UPDATE or DELETE statement. It has *no* other + uses and is not intended to provide the number of rows + present from a SELECT. + + Additionally, this value is only meaningful if the + dialect's supports_sane_rowcount flag is True for + single-parameter executions, or supports_sane_multi_rowcount + is true for multiple parameter executions - otherwise + results are undefined. 
+ + rowcount may not work at this time for a statement + that uses ``returning()``. + + """ + return self.context.rowcount @property def lastrowid(self): + """return the 'lastrowid' accessor on the DBAPI cursor. + + This is a DBAPI specific method and is only functional + for those backends which support it, for statements + where it is appropriate. It's behavior is not + consistent across backends. + + Usage of this method is normally unnecessary; the + inserted_primary_key method provides a + tuple of primary key values for a newly inserted row, + regardless of database backend. + + """ return self.cursor.lastrowid @property def out_parameters(self): return self.context.out_parameters - + + def _cursor_description(self): + return self.cursor.description + + def _autoclose(self): + if self.context.isinsert: + if self.context._is_implicit_returning: + self.context._fetch_implicit_returning(self) + self.close() + elif not self.context._is_explicit_returning: + self.close() + elif self._metadata is None: + # no results, get rowcount + # (which requires open cursor on some DB's such as firebird), + self.rowcount + self.close() # autoclose + + return self + + def _init_metadata(self): - metadata = self.cursor.description + self._metadata = metadata = self._cursor_description() if metadata is None: - # no results, get rowcount (which requires open cursor on some DB's such as firebird), - # then close - self._rowcount = self.context.get_rowcount() - self.close() return - - self._rowcount = None + self._props = util.populate_column_dict(None) self._props.creator = self.__key_fallback() self.keys = [] typemap = self.dialect.dbapi_type_map - for i, item in enumerate(metadata): - colname = item[0] + for i, (colname, coltype) in enumerate(m[0:2] for m in metadata): + if self.dialect.description_encoding: colname = colname.decode(self.dialect.description_encoding) @@ -1461,9 +1689,9 @@ class ResultProxy(object): try: (name, obj, type_) = 
self.context.result_map[colname.lower()] except KeyError: - (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE)) + (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE)) else: - (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE)) + (name, obj, type_) = (colname, None, typemap.get(coltype, types.NULLTYPE)) rec = (type_, type_.dialect_impl(self.dialect).result_processor(self.dialect), i) @@ -1474,7 +1702,10 @@ class ResultProxy(object): if origname: if self._props.setdefault(origname.lower(), rec) is not rec: self._props[origname.lower()] = (type_, self.__ambiguous_processor(origname), 0) - + + if self.dialect.requires_name_normalize: + colname = self.dialect.normalize_name(colname) + self.keys.append(colname) self._props[i] = rec if obj: @@ -1484,11 +1715,11 @@ class ResultProxy(object): if self._echo: self.context.engine.logger.debug( "Col " + repr(tuple(x[0] for x in metadata))) - + def __key_fallback(self): # create a closure without 'self' to avoid circular references props = self._props - + def fallback(key): if isinstance(key, basestring): key = key.lower() @@ -1515,19 +1746,22 @@ class ResultProxy(object): def close(self): """Close this ResultProxy. - + Closes the underlying DBAPI cursor corresponding to the execution. + + Note that any data cached within this ResultProxy is still available. + For some types of results, this may include buffered rows. If this ResultProxy was generated from an implicit execution, the underlying Connection will also be closed (returns the underlying DBAPI connection to the connection pool.) This method is called automatically when: - - * all result rows are exhausted using the fetchXXX() methods. - * cursor.description is None. - + + * all result rows are exhausted using the fetchXXX() methods. + * cursor.description is None. 
""" + if not self.closed: self.closed = True self.cursor.close() @@ -1550,53 +1784,66 @@ class ResultProxy(object): raise StopIteration else: yield row - - def last_inserted_ids(self): - """Return ``last_inserted_ids()`` from the underlying ExecutionContext. - - See ExecutionContext for details. + + @util.memoized_property + def inserted_primary_key(self): + """Return the primary key for the row just inserted. + + This only applies to single row insert() constructs which + did not explicitly specify returning(). """ - return self.context.last_inserted_ids() + if not self.context.isinsert: + raise exc.InvalidRequestError("Statement is not an insert() expression construct.") + elif self.context._is_explicit_returning: + raise exc.InvalidRequestError("Can't call inserted_primary_key when returning() is used.") + + return self.context._inserted_primary_key + @util.deprecated("Use inserted_primary_key") + def last_inserted_ids(self): + """deprecated. use inserted_primary_key.""" + + return self.inserted_primary_key + def last_updated_params(self): """Return ``last_updated_params()`` from the underlying ExecutionContext. See ExecutionContext for details. - """ + return self.context.last_updated_params() def last_inserted_params(self): """Return ``last_inserted_params()`` from the underlying ExecutionContext. See ExecutionContext for details. - """ + return self.context.last_inserted_params() def lastrow_has_defaults(self): """Return ``lastrow_has_defaults()`` from the underlying ExecutionContext. See ExecutionContext for details. - """ + return self.context.lastrow_has_defaults() def postfetch_cols(self): """Return ``postfetch_cols()`` from the underlying ExecutionContext. See ExecutionContext for details. 
- """ + return self.context.postfetch_cols - + def prefetch_cols(self): return self.context.prefetch_cols - + def supports_sane_rowcount(self): """Return ``supports_sane_rowcount`` from the dialect.""" - + return self.dialect.supports_sane_rowcount def supports_sane_multi_rowcount(self): @@ -1643,7 +1890,12 @@ class ResultProxy(object): raise def fetchmany(self, size=None): - """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``.""" + """Fetch many rows, just like DB-API ``cursor.fetchmany(size=cursor.arraysize)``. + + If rows are present, the cursor remains open after this is called. + Else the cursor is automatically closed and an empty list is returned. + + """ try: process_row = self._process_row @@ -1656,7 +1908,13 @@ class ResultProxy(object): raise def fetchone(self): - """Fetch one row, just like DB-API ``cursor.fetchone()``.""" + """Fetch one row, just like DB-API ``cursor.fetchone()``. + + If a row is present, the cursor remains open after this is called. + Else the cursor is automatically closed and None is returned. + + """ + try: row = self._fetchone_impl() if row is not None: @@ -1668,21 +1926,38 @@ class ResultProxy(object): self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) raise - def scalar(self): - """Fetch the first column of the first row, and close the result set.""" + def first(self): + """Fetch the first row and then close the result set unconditionally. + + Returns None if no row is present. + + """ try: row = self._fetchone_impl() except Exception, e: self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context) raise - + try: if row is not None: - return self._process_row(self, row)[0] + return self._process_row(self, row) else: return None finally: self.close() + + + def scalar(self): + """Fetch the first column of the first row, and close the result set. + + Returns None if no row is present. 
+ + """ + row = self.first() + if row is not None: + return row[0] + else: + return None class BufferedRowResultProxy(ResultProxy): """A ResultProxy with row buffering behavior. @@ -1697,7 +1972,6 @@ class BufferedRowResultProxy(ResultProxy): The pre-fetching behavior fetches only one row initially, and then grows its buffer size by a fixed amount with each successive need for additional rows up to a size of 100. - """ def _init_metadata(self): @@ -1740,7 +2014,44 @@ class BufferedRowResultProxy(ResultProxy): return result def _fetchall_impl(self): - return self.__rowbuffer + list(self.cursor.fetchall()) + ret = self.__rowbuffer + list(self.cursor.fetchall()) + self.__rowbuffer[:] = [] + return ret + +class FullyBufferedResultProxy(ResultProxy): + """A result proxy that buffers rows fully upon creation. + + Used for operations where a result is to be delivered + after the database conversation can not be continued, + such as MSSQL INSERT...OUTPUT after an autocommit. + + """ + def _init_metadata(self): + super(FullyBufferedResultProxy, self)._init_metadata() + self.__rowbuffer = self._buffer_rows() + + def _buffer_rows(self): + return self.cursor.fetchall() + + def _fetchone_impl(self): + if self.__rowbuffer: + return self.__rowbuffer.pop(0) + else: + return None + + def _fetchmany_impl(self, size=None): + result = [] + for x in range(0, size): + row = self._fetchone_impl() + if row is None: + break + result.append(row) + return result + + def _fetchall_impl(self): + ret = self.__rowbuffer + self.__rowbuffer = [] + return ret class BufferedColumnResultProxy(ResultProxy): """A ResultProxy with column buffering behavior. 
@@ -1791,28 +2102,6 @@ class BufferedColumnResultProxy(ResultProxy): return l -class SchemaIterator(schema.SchemaVisitor): - """A visitor that can gather text into a buffer and execute the contents of the buffer.""" - - def __init__(self, connection): - """Construct a new SchemaIterator.""" - - self.connection = connection - self.buffer = StringIO.StringIO() - - def append(self, s): - """Append content to the SchemaIterator's query buffer.""" - - self.buffer.write(s) - - def execute(self): - """Execute the contents of the SchemaIterator's buffer.""" - - try: - return self.connection.execute(self.buffer.getvalue()) - finally: - self.buffer.truncate(0) - class DefaultRunner(schema.SchemaVisitor): """A visitor which accepts ColumnDefault objects, produces the dialect-specific SQL corresponding to their execution, and @@ -1821,7 +2110,6 @@ class DefaultRunner(schema.SchemaVisitor): DefaultRunners are used internally by Engines and Dialects. Specific database modules should provide their own subclasses of DefaultRunner to allow database-specific behavior. - """ def __init__(self, context): @@ -1854,7 +2142,7 @@ class DefaultRunner(schema.SchemaVisitor): def execute_string(self, stmt, params=None): """execute a string statement, using the raw cursor, and return a scalar result.""" - + conn = self.context._connection if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements: stmt = stmt.encode(self.dialect.encoding) @@ -1883,8 +2171,8 @@ def connection_memoize(key): Only applicable to functions which take no arguments other than a connection. The memo will be stored in ``connection.info[key]``. 
- """ + @util.decorator def decorated(fn, self, connection): connection = connection.connect() diff --git a/lib/sqlalchemy/engine/ddl.py b/lib/sqlalchemy/engine/ddl.py new file mode 100644 index 000000000..6e7253e9a --- /dev/null +++ b/lib/sqlalchemy/engine/ddl.py @@ -0,0 +1,128 @@ +# engine/ddl.py +# Copyright (C) 2009 Michael Bayer mike_mp@zzzcomputing.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Routines to handle CREATE/DROP workflow.""" + +from sqlalchemy import engine, schema +from sqlalchemy.sql import util as sql_util + + +class DDLBase(schema.SchemaVisitor): + def __init__(self, connection): + self.connection = connection + +class SchemaGenerator(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables and set(tables) or None + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def _can_create(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or not self.dialect.has_table(self.connection, table.name, schema=table.schema) + + def visit_metadata(self, metadata): + if self.tables: + tables = self.tables + else: + tables = metadata.tables.values() + collection = [t for t in sql_util.sort_tables(tables) if self._can_create(t)] + + for listener in metadata.ddl_listeners['before-create']: + listener('before-create', metadata, self.connection, tables=collection) + + for table in collection: + self.traverse_single(table) + + for listener in metadata.ddl_listeners['after-create']: + listener('after-create', metadata, self.connection, tables=collection) + + def visit_table(self, table): + for listener in table.ddl_listeners['before-create']: + listener('before-create', table, self.connection) 
+ + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.connection.execute(schema.CreateTable(table)) + + if hasattr(table, 'indexes'): + for index in table.indexes: + self.traverse_single(index) + + for listener in table.ddl_listeners['after-create']: + listener('after-create', table, self.connection) + + def visit_sequence(self, sequence): + if self.dialect.supports_sequences: + if ((not self.dialect.sequences_optional or + not sequence.optional) and + (not self.checkfirst or + not self.dialect.has_sequence(self.connection, sequence.name))): + self.connection.execute(schema.CreateSequence(sequence)) + + def visit_index(self, index): + self.connection.execute(schema.CreateIndex(index)) + + +class SchemaDropper(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super(SchemaDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def visit_metadata(self, metadata): + if self.tables: + tables = self.tables + else: + tables = metadata.tables.values() + collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)] + + for listener in metadata.ddl_listeners['before-drop']: + listener('before-drop', metadata, self.connection, tables=collection) + + for table in collection: + self.traverse_single(table) + + for listener in metadata.ddl_listeners['after-drop']: + listener('after-drop', metadata, self.connection, tables=collection) + + def _can_drop(self, table): + self.dialect.validate_identifier(table.name) + if table.schema: + self.dialect.validate_identifier(table.schema) + return not self.checkfirst or self.dialect.has_table(self.connection, table.name, schema=table.schema) + + def visit_index(self, index): + self.connection.execute(schema.DropIndex(index)) + + def visit_table(self, table): + for listener in 
table.ddl_listeners['before-drop']: + listener('before-drop', table, self.connection) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + self.connection.execute(schema.DropTable(table)) + + for listener in table.ddl_listeners['after-drop']: + listener('after-drop', table, self.connection) + + def visit_sequence(self, sequence): + if self.dialect.supports_sequences: + if ((not self.dialect.sequences_optional or + not sequence.optional) and + (not self.checkfirst or + self.dialect.has_sequence(self.connection, sequence.name))): + self.connection.execute(schema.DropSequence(sequence)) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 728b932a2..935d1e087 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -13,36 +13,59 @@ as the base class for their own corresponding classes. """ import re, random -from sqlalchemy.engine import base +from sqlalchemy.engine import base, reflection from sqlalchemy.sql import compiler, expression -from sqlalchemy import exc +from sqlalchemy import exc, types as sqltypes, util AUTOCOMMIT_REGEXP = re.compile(r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)', re.I | re.UNICODE) + class DefaultDialect(base.Dialect): """Default implementation of Dialect""" - name = 'default' - schemagenerator = compiler.SchemaGenerator - schemadropper = compiler.SchemaDropper - statement_compiler = compiler.DefaultCompiler + statement_compiler = compiler.SQLCompiler + ddl_compiler = compiler.DDLCompiler + type_compiler = compiler.GenericTypeCompiler preparer = compiler.IdentifierPreparer defaultrunner = base.DefaultRunner supports_alter = True + + supports_sequences = False + sequences_optional = False + preexecute_autoincrement_sequences = False + postfetch_lastrowid = True + implicit_returning = False + + # Py3K + #supports_unicode_statements = True + #supports_unicode_binds = True + # Py2K supports_unicode_statements = False + 
supports_unicode_binds = False + # end Py2K + + name = 'default' max_identifier_length = 9999 supports_sane_rowcount = True supports_sane_multi_rowcount = True - preexecute_pk_sequences = False - supports_pk_autoincrement = True dbapi_type_map = {} default_paramstyle = 'named' - supports_default_values = False + supports_default_values = False supports_empty_insert = True + + # indicates symbol names are + # UPPERCASEd if they are case insensitive + # within the database. + # if this is True, the methods normalize_name() + # and denormalize_name() must be provided. + requires_name_normalize = False + + reflection_options = () def __init__(self, convert_unicode=False, assert_unicode=False, - encoding='utf-8', paramstyle=None, dbapi=None, + encoding='utf-8', paramstyle=None, dbapi=None, + implicit_returning=None, label_length=None, **kwargs): self.convert_unicode = convert_unicode self.assert_unicode = assert_unicode @@ -56,28 +79,58 @@ class DefaultDialect(base.Dialect): self.paramstyle = self.dbapi.paramstyle else: self.paramstyle = self.default_paramstyle + if implicit_returning is not None: + self.implicit_returning = implicit_returning self.positional = self.paramstyle in ('qmark', 'format', 'numeric') self.identifier_preparer = self.preparer(self) + self.type_compiler = self.type_compiler(self) + if label_length and label_length > self.max_identifier_length: - raise exc.ArgumentError("Label length of %d is greater than this dialect's maximum identifier length of %d" % (label_length, self.max_identifier_length)) + raise exc.ArgumentError("Label length of %d is greater than this dialect's" + " maximum identifier length of %d" % + (label_length, self.max_identifier_length)) self.label_length = label_length - self.description_encoding = getattr(self, 'description_encoding', encoding) - def type_descriptor(self, typeobj): + if not hasattr(self, 'description_encoding'): + self.description_encoding = getattr(self, 'description_encoding', encoding) + + # Py3K + ## work 
around dialects that might change these values + #self.supports_unicode_statements = True + #self.supports_unicode_binds = True + + def initialize(self, connection): + if hasattr(self, '_get_server_version_info'): + self.server_version_info = self._get_server_version_info(connection) + if hasattr(self, '_get_default_schema_name'): + self.default_schema_name = self._get_default_schema_name(connection) + + @classmethod + def type_descriptor(cls, typeobj): """Provide a database-specific ``TypeEngine`` object, given the generic object which comes from the types module. - Subclasses will usually use the ``adapt_type()`` method in the - types module to make this job easy.""" + This method looks for a dictionary called + ``colspecs`` as a class or instance-level variable, + and passes on to ``types.adapt_type()``. - if type(typeobj) is type: - typeobj = typeobj() - return typeobj + """ + return sqltypes.adapt_type(typeobj, cls.colspecs) + + def reflecttable(self, connection, table, include_columns): + insp = reflection.Inspector.from_engine(connection) + return insp.reflecttable(table, include_columns) def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: - raise exc.IdentifierError("Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length)) - + raise exc.IdentifierError( + "Identifier '%s' exceeds maximum length of %d characters" % + (ident, self.max_identifier_length) + ) + + def connect(self, *cargs, **cparams): + return self.dbapi.connect(*cargs, **cparams) + def do_begin(self, connection): """Implementations might want to put logic here for turning autocommit on/off, etc. @@ -103,7 +156,8 @@ class DefaultDialect(base.Dialect): """Create a random two-phase transaction ID. This id will be passed to do_begin_twophase(), do_rollback_twophase(), - do_commit_twophase(). Its format is unspecified.""" + do_commit_twophase(). Its format is unspecified. 
+ """ return "_sa_%032x" % random.randint(0, 2 ** 128) @@ -127,13 +181,30 @@ class DefaultDialect(base.Dialect): class DefaultExecutionContext(base.ExecutionContext): - def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): + + def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None): self.dialect = dialect self._connection = self.root_connection = connection - self.compiled = compiled self.engine = connection.engine - if compiled is not None: + if compiled_ddl is not None: + self.compiled = compiled = compiled_ddl + if not dialect.supports_unicode_statements: + self.statement = unicode(compiled).encode(self.dialect.encoding) + else: + self.statement = unicode(compiled) + self.isinsert = self.isupdate = self.isdelete = self.executemany = False + self.should_autocommit = True + self.result_map = None + self.cursor = self.create_cursor() + self.compiled_parameters = [] + if self.dialect.positional: + self.parameters = [()] + else: + self.parameters = [{}] + elif compiled_sql is not None: + self.compiled = compiled = compiled_sql + # compiled clauseelement. process bind params, process table defaults, # track collections used by ResultProxy to target and process results @@ -156,6 +227,7 @@ class DefaultExecutionContext(base.ExecutionContext): self.isinsert = compiled.isinsert self.isupdate = compiled.isupdate + self.isdelete = compiled.isdelete self.should_autocommit = compiled.statement._autocommit if isinstance(compiled.statement, expression._TextClause): self.should_autocommit = self.should_autocommit or self.should_autocommit_text(self.statement) @@ -173,31 +245,43 @@ class DefaultExecutionContext(base.ExecutionContext): self.parameters = self.__convert_compiled_params(self.compiled_parameters) elif statement is not None: - # plain text statement. 
- self.result_map = None + # plain text statement + self.result_map = self.compiled = None self.parameters = self.__encode_param_keys(parameters) self.executemany = len(parameters) > 1 if isinstance(statement, unicode) and not dialect.supports_unicode_statements: self.statement = statement.encode(self.dialect.encoding) else: self.statement = statement - self.isinsert = self.isupdate = False + self.isinsert = self.isupdate = self.isdelete = False self.cursor = self.create_cursor() self.should_autocommit = self.should_autocommit_text(statement) else: # no statement. used for standalone ColumnDefault execution. - self.statement = None - self.isinsert = self.isupdate = self.executemany = self.should_autocommit = False + self.statement = self.compiled = None + self.isinsert = self.isupdate = self.isdelete = self.executemany = self.should_autocommit = False self.cursor = self.create_cursor() - + + @util.memoized_property + def _is_explicit_returning(self): + return self.compiled and \ + getattr(self.compiled.statement, '_returning', False) + + @util.memoized_property + def _is_implicit_returning(self): + return self.compiled and \ + bool(self.compiled.returning) and \ + not self.compiled.statement._returning + @property def connection(self): return self._connection._branch() def __encode_param_keys(self, params): - """apply string encoding to the keys of dictionary-based bind parameters. + """Apply string encoding to the keys of dictionary-based bind parameters. - This is only used executing textual, non-compiled SQL expressions.""" + This is only used executing textual, non-compiled SQL expressions. 
+ """ if self.dialect.positional or self.dialect.supports_unicode_statements: if params: @@ -216,7 +300,7 @@ class DefaultExecutionContext(base.ExecutionContext): return [proc(d) for d in params] or [{}] def __convert_compiled_params(self, compiled_parameters): - """convert the dictionary of bind parameter values into a dict or list + """Convert the dictionary of bind parameter values into a dict or list to be sent to the DBAPI's execute() or executemany() method. """ @@ -263,26 +347,69 @@ class DefaultExecutionContext(base.ExecutionContext): def post_exec(self): pass + def get_lastrowid(self): + """return self.cursor.lastrowid, or equivalent, after an INSERT. + + This may involve calling special cursor functions, + issuing a new SELECT on the cursor (or a new one), + or returning a stored value that was + calculated within post_exec(). + + This function will only be called for dialects + which support "implicit" primary key generation, + keep preexecute_autoincrement_sequences set to False, + and when no explicit id value was bound to the + statement. + + The function is called once, directly after + post_exec() and before the transaction is committed + or ResultProxy is generated. If the post_exec() + method assigns a value to `self._lastrowid`, the + value is used in place of calling get_lastrowid(). + + Note that this method is *not* equivalent to the + ``lastrowid`` method on ``ResultProxy``, which is a + direct proxy to the DBAPI ``lastrowid`` accessor + in all cases. 
+ + """ + + return self.cursor.lastrowid + def handle_dbapi_exception(self, e): pass def get_result_proxy(self): return base.ResultProxy(self) + + @property + def rowcount(self): + return self.cursor.rowcount - def get_rowcount(self): - if hasattr(self, '_rowcount'): - return self._rowcount - else: - return self.cursor.rowcount - def supports_sane_rowcount(self): return self.dialect.supports_sane_rowcount def supports_sane_multi_rowcount(self): return self.dialect.supports_sane_multi_rowcount - - def last_inserted_ids(self): - return self._last_inserted_ids + + def post_insert(self): + if self.dialect.postfetch_lastrowid and \ + (not len(self._inserted_primary_key) or \ + None in self._inserted_primary_key): + + table = self.compiled.statement.table + lastrowid = self.get_lastrowid() + self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v + for c, v in zip(table.primary_key, self._inserted_primary_key) + ] + + def _fetch_implicit_returning(self, resultproxy): + table = self.compiled.statement.table + row = resultproxy.first() + + self._inserted_primary_key = [v is not None and v or row[c] + for c, v in zip(table.primary_key, self._inserted_primary_key) + ] def last_inserted_params(self): return self._last_inserted_params @@ -293,12 +420,15 @@ class DefaultExecutionContext(base.ExecutionContext): def lastrow_has_defaults(self): return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols) - def set_input_sizes(self): + def set_input_sizes(self, translate=None, exclude_types=None): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. 
""" + if not hasattr(self.compiled, 'bind_names'): + return + types = dict( (self.compiled.bind_names[bindparam], bindparam.type) for bindparam in self.compiled.bind_names) @@ -308,7 +438,7 @@ class DefaultExecutionContext(base.ExecutionContext): for key in self.compiled.positiontup: typeengine = types[key] dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if dbtype is not None: + if dbtype is not None and (not exclude_types or dbtype not in exclude_types): inputsizes.append(dbtype) try: self.cursor.setinputsizes(*inputsizes) @@ -320,7 +450,9 @@ class DefaultExecutionContext(base.ExecutionContext): for key in self.compiled.bind_names.values(): typeengine = types[key] dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if dbtype is not None: + if dbtype is not None and (not exclude_types or dbtype not in exclude_types): + if translate: + key = translate.get(key, key) inputsizes[key.encode(self.dialect.encoding)] = dbtype try: self.cursor.setinputsizes(**inputsizes) @@ -329,8 +461,9 @@ class DefaultExecutionContext(base.ExecutionContext): raise def __process_defaults(self): - """generate default values for compiled insert/update statements, - and generate last_inserted_ids() collection.""" + """Generate default values for compiled insert/update statements, + and generate inserted_primary_key collection. 
+ """ if self.executemany: if len(self.compiled.prefetch): @@ -364,7 +497,8 @@ class DefaultExecutionContext(base.ExecutionContext): compiled_parameters[c.key] = val if self.isinsert: - self._last_inserted_ids = [compiled_parameters.get(c.key, None) for c in self.compiled.statement.table.primary_key] + self._inserted_primary_key = [compiled_parameters.get(c.key, None) + for c in self.compiled.statement.table.primary_key] self._last_inserted_params = compiled_parameters else: self._last_updated_params = compiled_parameters diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py new file mode 100644 index 000000000..173e0fab0 --- /dev/null +++ b/lib/sqlalchemy/engine/reflection.py @@ -0,0 +1,361 @@ +"""Provides an abstraction for obtaining database schema information. + +Usage Notes: + +Here are some general conventions when accessing the low level inspector +methods such as get_table_names, get_columns, etc. + +1. Inspector methods return lists of dicts in most cases for the following + reasons: + + * They're both standard types that can be serialized. + * Using a dict instead of a tuple allows easy expansion of attributes. + * Using a list for the outer structure maintains order and is easy to work + with (e.g. list comprehension [d['name'] for d in cols]). + +2. Records that contain a name, such as the column name in a column record + use the key 'name'. So for most return values, each record will have a + 'name' attribute.. 
+""" + +import sqlalchemy +from sqlalchemy import exc, sql +from sqlalchemy import util +from sqlalchemy.types import TypeEngine +from sqlalchemy import schema as sa_schema + + +@util.decorator +def cache(fn, self, con, *args, **kw): + info_cache = kw.get('info_cache', None) + if info_cache is None: + return fn(self, con, *args, **kw) + key = ( + fn.__name__, + tuple(a for a in args if isinstance(a, basestring)), + tuple((k, v) for k, v in kw.iteritems() if isinstance(v, basestring)) + ) + ret = info_cache.get(key) + if ret is None: + ret = fn(self, con, *args, **kw) + info_cache[key] = ret + return ret + + +class Inspector(object): + """Performs database schema inspection. + + The Inspector acts as a proxy to the dialects' reflection methods and + provides higher level functions for accessing database schema information. + """ + + def __init__(self, conn): + """Initialize the instance. + + :param conn: a :class:`~sqlalchemy.engine.base.Connectable` + """ + + self.conn = conn + # set the engine + if hasattr(conn, 'engine'): + self.engine = conn.engine + else: + self.engine = conn + self.dialect = self.engine.dialect + self.info_cache = {} + + @classmethod + def from_engine(cls, engine): + if hasattr(engine.dialect, 'inspector'): + return engine.dialect.inspector(engine) + return Inspector(engine) + + @property + def default_schema_name(self): + return self.dialect.get_default_schema_name(self.conn) + + def get_schema_names(self): + """Return all schema names. + """ + + if hasattr(self.dialect, 'get_schema_names'): + return self.dialect.get_schema_names(self.conn, + info_cache=self.info_cache) + return [] + + def get_table_names(self, schema=None, order_by=None): + """Return all table names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + :param order_by: Optional, may be the string "foreign_key" to sort + the result on foreign key dependencies. 
+
+        This should probably not return view names or maybe it should return
+        them with an indicator t or v.
+        """
+
+        if hasattr(self.dialect, 'get_table_names'):
+            tnames = self.dialect.get_table_names(self.conn,
+            schema,
+                                                 info_cache=self.info_cache)
+        else:
+            tnames = self.engine.table_names(schema)
+        if order_by == 'foreign_key':
+            ordered_tnames = tnames[:]
+            # Order based on foreign key dependencies.
+            for tname in tnames:
+                table_pos = tnames.index(tname)
+                fkeys = self.get_foreign_keys(tname, schema)
+                for fkey in fkeys:
+                    rtable = fkey['referred_table']
+                    if rtable in ordered_tnames:
+                        ref_pos = ordered_tnames.index(rtable)
+                        # Make sure it's lower in the list than anything it
+                        # references.
+                        if table_pos > ref_pos:
+                            ordered_tnames.pop(table_pos) # rtable moves up 1
+                            # insert just below rtable
+                            ordered_tnames.insert(ref_pos, tname)
+            tnames = ordered_tnames
+        return tnames
+
+    def get_table_options(self, table_name, schema=None, **kw):
+        if hasattr(self.dialect, 'get_table_options'):
+            return self.dialect.get_table_options(self.conn, table_name, schema,
+                                                  info_cache=self.info_cache,
+                                                  **kw)
+        return {}
+
+    def get_view_names(self, schema=None):
+        """Return all view names in `schema`.
+
+        :param schema: Optional, retrieve names from a non-default schema.
+        """
+
+        return self.dialect.get_view_names(self.conn, schema,
+                                           info_cache=self.info_cache)
+
+    def get_view_definition(self, view_name, schema=None):
+        """Return definition for `view_name`.
+
+        :param schema: Optional, retrieve names from a non-default schema.
+        """
+
+        return self.dialect.get_view_definition(
+            self.conn, view_name, schema, info_cache=self.info_cache)
+
+    def get_columns(self, table_name, schema=None, **kw):
+        """Return information about columns in `table_name`. 
+ + Given a string `table_name` and an optional string `schema`, return + column information as a list of dicts with these keys: + + name + the column's name + + type + :class:`~sqlalchemy.types.TypeEngine` + + nullable + boolean + + default + the column's default value + + attrs + dict containing optional column attributes + """ + + col_defs = self.dialect.get_columns(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + for col_def in col_defs: + # make this easy and only return instances for coltype + coltype = col_def['type'] + if not isinstance(coltype, TypeEngine): + col_def['type'] = coltype() + return col_defs + + def get_primary_keys(self, table_name, schema=None, **kw): + """Return information about primary keys in `table_name`. + + Given a string `table_name`, and an optional string `schema`, return + primary key information as a list of column names. + """ + + pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + + return pkeys + + def get_foreign_keys(self, table_name, schema=None, **kw): + """Return information about foreign_keys in `table_name`. + + Given a string `table_name`, and an optional string `schema`, return + foreign key information as a list of dicts with these keys: + + constrained_columns + a list of column names that make up the foreign key + + referred_schema + the name of the referred schema + + referred_table + the name of the referred table + + referred_columns + a list of column names in the referred table that correspond to + constrained_columns + """ + + fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + return fk_defs + + def get_indexes(self, table_name, schema=None): + """Return information about indexes in `table_name`. 
+ + Given a string `table_name` and an optional string `schema`, return + index information as a list of dicts with these keys: + + name + the index's name + + column_names + list of column names in order + + unique + boolean + """ + + indexes = self.dialect.get_indexes(self.conn, table_name, + schema, + info_cache=self.info_cache) + return indexes + + def reflecttable(self, table, include_columns): + + dialect = self.conn.dialect + + # MySQL dialect does this. Applicable with other dialects? + if hasattr(dialect, '_connection_charset') \ + and hasattr(dialect, '_adjust_casing'): + charset = dialect._connection_charset + dialect._adjust_casing(table) + + # table attributes we might need. + reflection_options = dict( + (k, table.kwargs.get(k)) for k in dialect.reflection_options if k in table.kwargs) + + schema = table.schema + table_name = table.name + + # apply table options + tbl_opts = self.get_table_options(table_name, schema, **table.kwargs) + if tbl_opts: + table.kwargs.update(tbl_opts) + + # table.kwargs will need to be passed to each reflection method. Make + # sure keywords are strings. + tblkw = table.kwargs.copy() + for (k, v) in tblkw.items(): + del tblkw[k] + tblkw[str(k)] = v + + # Py2K + if isinstance(schema, str): + schema = schema.decode(dialect.encoding) + if isinstance(table_name, str): + table_name = table_name.decode(dialect.encoding) + # end Py2K + + # columns + found_table = False + for col_d in self.get_columns(table_name, schema, **tblkw): + found_table = True + name = col_d['name'] + if include_columns and name not in include_columns: + continue + + coltype = col_d['type'] + col_kw = { + 'nullable':col_d['nullable'], + } + if 'autoincrement' in col_d: + col_kw['autoincrement'] = col_d['autoincrement'] + + colargs = [] + if col_d.get('default') is not None: + # the "default" value is assumed to be a literal SQL expression, + # so is wrapped in text() so that no quoting occurs on re-issuance. 
+ colargs.append(sa_schema.DefaultClause(sql.text(col_d['default']))) + + if 'sequence' in col_d: + # TODO: whos using this ? + seq = col_d['sequence'] + sequence = sa_schema.Sequence(seq['name'], 1, 1) + if 'start' in seq: + sequence.start = seq['start'] + if 'increment' in seq: + sequence.increment = seq['increment'] + colargs.append(sequence) + + col = sa_schema.Column(name, coltype, *colargs, **col_kw) + table.append_column(col) + + if not found_table: + raise exc.NoSuchTableError(table.name) + + # Primary keys + primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[ + table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw) + if pk in table.c + ]) + + table.append_constraint(primary_key_constraint) + + # Foreign keys + fkeys = self.get_foreign_keys(table_name, schema, **tblkw) + for fkey_d in fkeys: + conname = fkey_d['name'] + constrained_columns = fkey_d['constrained_columns'] + referred_schema = fkey_d['referred_schema'] + referred_table = fkey_d['referred_table'] + referred_columns = fkey_d['referred_columns'] + refspec = [] + if referred_schema is not None: + sa_schema.Table(referred_table, table.metadata, + autoload=True, schema=referred_schema, + autoload_with=self.conn, + **reflection_options + ) + for column in referred_columns: + refspec.append(".".join( + [referred_schema, referred_table, column])) + else: + sa_schema.Table(referred_table, table.metadata, autoload=True, + autoload_with=self.conn, + **reflection_options + ) + for column in referred_columns: + refspec.append(".".join([referred_table, column])) + table.append_constraint( + sa_schema.ForeignKeyConstraint(constrained_columns, refspec, + conname, link_to_name=True)) + # Indexes + indexes = self.get_indexes(table_name, schema) + for index_d in indexes: + name = index_d['name'] + columns = index_d['column_names'] + unique = index_d['unique'] + flavor = index_d.get('type', 'unknown type') + if include_columns and \ + not set(columns).issubset(include_columns): + 
util.warn( + "Omitting %s KEY for (%s), key covers omitted columns." % + (flavor, ', '.join(columns))) + continue + sa_schema.Index(name, *[table.columns[c] for c in columns], + **dict(unique=unique)) diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index fa608df65..ff62b265b 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -6,31 +6,26 @@ underlying behavior for the "strategy" keyword argument available on ``plain``, ``threadlocal``, and ``mock``. New strategies can be added via new ``EngineStrategy`` classes. - """ + from operator import attrgetter from sqlalchemy.engine import base, threadlocal, url from sqlalchemy import util, exc from sqlalchemy import pool as poollib - strategies = {} + class EngineStrategy(object): """An adaptor that processes input arguements and produces an Engine. Provides a ``create`` method that receives input arguments and produces an instance of base.Engine or a subclass. + """ - def __init__(self, name): - """Construct a new EngineStrategy object. - - Sets it in the list of available strategies under this name. 
- """ - - self.name = name + def __init__(self): strategies[self.name] = self def create(self, *args, **kwargs): @@ -38,9 +33,12 @@ class EngineStrategy(object): raise NotImplementedError() + class DefaultEngineStrategy(EngineStrategy): """Base class for built-in stratgies.""" + pool_threadlocal = False + def create(self, name_or_url, **kwargs): # create url.URL object u = url.make_url(name_or_url) @@ -75,9 +73,15 @@ class DefaultEngineStrategy(EngineStrategy): if pool is None: def connect(): try: - return dbapi.connect(*cargs, **cparams) + return dialect.connect(*cargs, **cparams) except Exception, e: - raise exc.DBAPIError.instance(None, None, e) + # Py3K + #raise exc.DBAPIError.instance(None, None, e) from e + # Py2K + import sys + raise exc.DBAPIError.instance(None, None, e), None, sys.exc_info()[2] + # end Py2K + creator = kwargs.pop('creator', connect) poolclass = (kwargs.pop('poolclass', None) or @@ -94,7 +98,7 @@ class DefaultEngineStrategy(EngineStrategy): tk = translate.get(k, k) if tk in kwargs: pool_args[k] = kwargs.pop(tk) - pool_args.setdefault('use_threadlocal', self.pool_threadlocal()) + pool_args.setdefault('use_threadlocal', self.pool_threadlocal) pool = poolclass(creator, **pool_args) else: if isinstance(pool, poollib._DBProxy): @@ -103,12 +107,14 @@ class DefaultEngineStrategy(EngineStrategy): pool = pool # create engine. 
- engineclass = self.get_engine_cls() + engineclass = self.engine_cls engine_args = {} for k in util.get_cls_kwargs(engineclass): if k in kwargs: engine_args[k] = kwargs.pop(k) + _initialize = kwargs.pop('_initialize', True) + # all kwargs should be consumed if kwargs: raise TypeError( @@ -119,39 +125,38 @@ class DefaultEngineStrategy(EngineStrategy): dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__)) - return engineclass(pool, dialect, u, **engine_args) + + engine = engineclass(pool, dialect, u, **engine_args) - def pool_threadlocal(self): - raise NotImplementedError() + if _initialize: + # some unit tests pass through _initialize=False + # to help mock engines work + class OnInit(object): + def first_connect(self, conn, rec): + c = base.Connection(engine, connection=conn) + dialect.initialize(c) + pool._on_first_connect.insert(0, OnInit()) - def get_engine_cls(self): - raise NotImplementedError() + dialect.visit_pool(pool) -class PlainEngineStrategy(DefaultEngineStrategy): - """Strategy for configuring a regular Engine.""" + return engine - def __init__(self): - DefaultEngineStrategy.__init__(self, 'plain') - - def pool_threadlocal(self): - return False - def get_engine_cls(self): - return base.Engine +class PlainEngineStrategy(DefaultEngineStrategy): + """Strategy for configuring a regular Engine.""" + name = 'plain' + engine_cls = base.Engine + PlainEngineStrategy() + class ThreadLocalEngineStrategy(DefaultEngineStrategy): """Strategy for configuring an Engine with thredlocal behavior.""" - - def __init__(self): - DefaultEngineStrategy.__init__(self, 'threadlocal') - - def pool_threadlocal(self): - return True - - def get_engine_cls(self): - return threadlocal.TLEngine + + name = 'threadlocal' + pool_threadlocal = True + engine_cls = threadlocal.TLEngine ThreadLocalEngineStrategy() @@ -161,11 +166,11 @@ class MockEngineStrategy(EngineStrategy): Produces a single mock Connectable object which dispatches statement execution to a passed-in 
function. + """ - def __init__(self): - EngineStrategy.__init__(self, 'mock') - + name = 'mock' + def create(self, name_or_url, executor, **kwargs): # create url.URL object u = url.make_url(name_or_url) @@ -201,11 +206,14 @@ class MockEngineStrategy(EngineStrategy): def create(self, entity, **kwargs): kwargs['checkfirst'] = False - self.dialect.schemagenerator(self.dialect, self, **kwargs).traverse(entity) + from sqlalchemy.engine import ddl + + ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse(entity) def drop(self, entity, **kwargs): kwargs['checkfirst'] = False - self.dialect.schemadropper(self.dialect, self, **kwargs).traverse(entity) + from sqlalchemy.engine import ddl + ddl.SchemaDropper(self.dialect, self, **kwargs).traverse(entity) def execute(self, object, *multiparams, **params): raise NotImplementedError() diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 8ad14ad35..27d857623 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -8,6 +8,7 @@ invoked automatically when the threadlocal engine strategy is used. from sqlalchemy import util from sqlalchemy.engine import base + class TLSession(object): def __init__(self, engine): self.engine = engine @@ -17,7 +18,8 @@ class TLSession(object): try: return self.__transaction._increment_connect() except AttributeError: - return self.engine.TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result) + return self.engine.TLConnection(self, self.engine.pool.connect(), + close_with_result=close_with_result) def reset(self): try: diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 5c8e68ce4..b0e21f5f7 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -20,9 +20,9 @@ class URL(object): format of the URL is an RFC-1738-style string. All initialization parameters are available as public attributes. 
- - :param drivername: the name of the database backend. - This name will correspond to a module in sqlalchemy/databases + + :param drivername: the name of the database backend. + This name will correspond to a module in sqlalchemy/databases or a third party plug-in. :param username: The user name. @@ -35,12 +35,13 @@ class URL(object): :param database: The database name. - :param query: A dictionary of options to be passed to the + :param query: A dictionary of options to be passed to the dialect and/or the DBAPI upon connect. - + """ - def __init__(self, drivername, username=None, password=None, host=None, port=None, database=None, query=None): + def __init__(self, drivername, username=None, password=None, + host=None, port=None, database=None, query=None): self.drivername = drivername self.username = username self.password = password @@ -70,10 +71,10 @@ class URL(object): keys.sort() s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys) return s - + def __hash__(self): return hash(str(self)) - + def __eq__(self, other): return \ isinstance(other, URL) and \ @@ -83,12 +84,22 @@ class URL(object): self.host == other.host and \ self.database == other.database and \ self.query == other.query - + def get_dialect(self): - """Return the SQLAlchemy database dialect class corresponding to this URL's driver name.""" - + """Return the SQLAlchemy database dialect class corresponding + to this URL's driver name. 
+ """ + try: - module = getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername) + if '+' in self.drivername: + dialect, driver = self.drivername.split('+') + else: + dialect, driver = self.drivername, 'base' + + module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects + module = getattr(module, dialect) + module = getattr(module, driver) + return module.dialect except ImportError: if sys.exc_info()[2].tb_next is None: @@ -97,7 +108,7 @@ class URL(object): if res.name == self.drivername: return res.load() raise - + def translate_connect_args(self, names=[], **kw): """Translate url attributes into a dictionary of connection arguments. @@ -107,10 +118,9 @@ class URL(object): from the final dictionary. :param \**kw: Optional, alternate key names for url attributes. - + :param names: Deprecated. Same purpose as the keyword-based alternate names, but correlates the name to the original positionally. - """ translated = {} @@ -131,8 +141,8 @@ def make_url(name_or_url): The given string is parsed according to the RFC 1738 spec. If an existing URL object is passed, just returns the object. - """ + if isinstance(name_or_url, basestring): return _parse_rfc1738_args(name_or_url) else: @@ -140,7 +150,7 @@ def make_url(name_or_url): def _parse_rfc1738_args(name): pattern = re.compile(r''' - (?P<name>\w+):// + (?P<name>[\w\+]+):// (?: (?P<username>[^:/]*) (?::(?P<password>[^/]*))? @@ -160,8 +170,10 @@ def _parse_rfc1738_args(name): tokens = components['database'].split('?', 2) components['database'] = tokens[0] query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None + # Py2K if query is not None: query = dict((k.encode('ascii'), query[k]) for k in query) + # end Py2K else: query = None components['query'] = query |
