diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/ext/proxy.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/mapping/objectstore.py | 209 | ||||
| -rw-r--r-- | lib/sqlalchemy/mapping/query.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/schema.py | 55 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql.py | 66 |
5 files changed, 204 insertions, 141 deletions
diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py index 38325bea3..a24f089e9 100644 --- a/lib/sqlalchemy/ext/proxy.py +++ b/lib/sqlalchemy/ext/proxy.py @@ -24,7 +24,15 @@ class BaseProxyEngine(schema.SchemaEngine): def reflecttable(self, table): return self.get_engine().reflecttable(table) - + def execute_compiled(self, *args, **kwargs): + return self.get_engine().execute_compiled(*args, **kwargs) + def compiler(self, *args, **kwargs): + return self.get_engine().compiler(*args, **kwargs) + def schemagenerator(self, *args, **kwargs): + return self.get_engine().schemagenerator(*args, **kwargs) + def schemadropper(self, *args, **kwargs): + return self.get_engine().schemadropper(*args, **kwargs) + def hash_key(self): return "%s(%s)" % (self.__class__.__name__, id(self)) diff --git a/lib/sqlalchemy/mapping/objectstore.py b/lib/sqlalchemy/mapping/objectstore.py index 1491d39ac..faf5ddbd6 100644 --- a/lib/sqlalchemy/mapping/objectstore.py +++ b/lib/sqlalchemy/mapping/objectstore.py @@ -17,7 +17,7 @@ import sqlalchemy class Session(object): """Maintains a UnitOfWork instance, including transaction state.""" - def __init__(self, nest_on=None, hash_key=None): + def __init__(self, hash_key=None, new_imap=True, import_session=None): """Initialize the objectstore with a UnitOfWork registry. If called with no arguments, creates a single UnitOfWork for all operations. @@ -26,31 +26,23 @@ class Session(object): hash_key - the hash_key used to identify objects against this session, which defaults to the id of the Session instance. """ - self.uow = unitofwork.UnitOfWork() - self.parent_uow = None - self.begin_count = 0 - self.nest_on = util.to_list(nest_on) - self.__pushed_count = 0 + if import_session is not None: + self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map) + elif new_imap is False: + self.uow = unitofwork.UnitOfWork(identity_map=objectstore.get_session().uow.identity_map) + else: + self.uow = unitofwork.UnitOfWork() + + self.binds = {} if hash_key is None: self.hash_key = id(self) else: self.hash_key = hash_key _sessions[self.hash_key] = self - def was_pushed(self): - if self.nest_on is None: - return - self.__pushed_count += 1 - if self.__pushed_count == 1: - for n in self.nest_on: - n.push_session() - def was_popped(self): - if self.nest_on is None or self.__pushed_count == 0: - return - self.__pushed_count -= 1 - if self.__pushed_count == 0: - for n in self.nest_on: - n.pop_session() + def bind_table(self, table, bindto): + self.binds[table] = bindto + def get_id_key(ident, class_, entity_name=None): """returns an identity-map key for use in storing/retrieving an item from the identity map, given a tuple of the object's primary key values. @@ -81,79 +73,12 @@ class Session(object): """ return (class_, tuple([row[column] for column in primary_key]), entity_name) get_row_key = staticmethod(get_row_key) - - class SessionTrans(object): - """returned by Session.begin(), denotes a transactionalized UnitOfWork instance. - call commit() on this to commit the transaction.""" - def __init__(self, parent, uow, isactive): - self.__parent = parent - self.__isactive = isactive - self.__uow = uow - isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.") - parent = property(lambda s:s.__parent, doc="returns the parent Session of this SessionTrans object.") - uow = property(lambda s:s.__uow, doc="returns the parent UnitOfWork corresponding to this transaction.") - def begin(self): - """calls begin() on the underlying Session object, returning a new no-op SessionTrans object.""" - if self.parent.uow is not self.uow: - raise InvalidRequestError("This SessionTrans is no longer valid") - return self.parent.begin() - def commit(self): - """commits the transaction noted by this SessionTrans object.""" - self.__parent._trans_commit(self) - self.__isactive = False - def rollback(self): - """rolls back the current UnitOfWork transaction, in the case that begin() - has been called. The changes logged since the begin() call are discarded.""" - self.__parent._trans_rollback(self) - self.__isactive = False - - def begin(self): - """begins a new UnitOfWork transaction and returns a tranasaction-holding - object. commit() or rollback() should be called on the returned object. - commit() on the Session will do nothing while a transaction is pending, and further - calls to begin() will return no-op transactional objects.""" - if self.parent_uow is not None: - return Session.SessionTrans(self, self.uow, False) - self.parent_uow = self.uow - self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map) - return Session.SessionTrans(self, self.uow, True) def engines(self, mapper): return [t.engine for t in mapper.tables] - def _trans_commit(self, trans): - if trans.uow is self.uow and trans.isactive: - try: - self._commit_uow() - finally: - self.uow = self.parent_uow - self.parent_uow = None - def _trans_rollback(self, trans): - if trans.uow is self.uow: - self.uow = self.parent_uow - self.parent_uow = None - - def _commit_uow(self, *obj): - self.was_pushed() - try: - self.uow.flush(self, *obj) - finally: - self.was_popped() - - def commit(self, *objects): - """commits the current UnitOfWork transaction. called with - no arguments, this is only used - for "implicit" transactions when there was no begin(). - if individual objects are submitted, then only those objects are committed, and the - begin/commit cycle is not affected.""" - # if an object list is given, commit just those but dont - # change begin/commit status - if len(objects): - self._commit_uow(*objects) - self.uow.flush(self, *objects) - return - if self.parent_uow is None: - self._commit_uow() + def flush(self, *obj): + self.uow.flush(self, *obj) def refresh(self, *obj): """reloads the attributes for the given objects from the database, clears @@ -221,6 +146,95 @@ class Session(object): u.register_new(instance) return instance +class LegacySession(Session): + def __init__(self, nest_on=None, hash_key=None, **kwargs): + super(LegacySession, self).__init__(**kwargs) + self.parent_uow = None + self.begin_count = 0 + self.nest_on = util.to_list(nest_on) + self.__pushed_count = 0 + def was_pushed(self): + if self.nest_on is None: + return + self.__pushed_count += 1 + if self.__pushed_count == 1: + for n in self.nest_on: + n.push_session() + def was_popped(self): + if self.nest_on is None or self.__pushed_count == 0: + return + self.__pushed_count -= 1 + if self.__pushed_count == 0: + for n in self.nest_on: + n.pop_session() + class SessionTrans(object): + """returned by Session.begin(), denotes a transactionalized UnitOfWork instance. + call commit() on this to commit the transaction.""" + def __init__(self, parent, uow, isactive): + self.__parent = parent + self.__isactive = isactive + self.__uow = uow + isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.") + parent = property(lambda s:s.__parent, doc="returns the parent Session of this SessionTrans object.") + uow = property(lambda s:s.__uow, doc="returns the parent UnitOfWork corresponding to this transaction.") + def begin(self): + """calls begin() on the underlying Session object, returning a new no-op SessionTrans object.""" + if self.parent.uow is not self.uow: + raise InvalidRequestError("This SessionTrans is no longer valid") + return self.parent.begin() + def commit(self): + """commits the transaction noted by this SessionTrans object.""" + self.__parent._trans_commit(self) + self.__isactive = False + def rollback(self): + """rolls back the current UnitOfWork transaction, in the case that begin() + has been called. The changes logged since the begin() call are discarded.""" + self.__parent._trans_rollback(self) + self.__isactive = False + def begin(self): + """begins a new UnitOfWork transaction and returns a tranasaction-holding + object. commit() or rollback() should be called on the returned object. + commit() on the Session will do nothing while a transaction is pending, and further + calls to begin() will return no-op transactional objects.""" + if self.parent_uow is not None: + return Session.SessionTrans(self, self.uow, False) + self.parent_uow = self.uow + self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map) + return Session.SessionTrans(self, self.uow, True) + def commit(self, *objects): + """commits the current UnitOfWork transaction. called with + no arguments, this is only used + for "implicit" transactions when there was no begin(). + if individual objects are submitted, then only those objects are committed, and the + begin/commit cycle is not affected.""" + # if an object list is given, commit just those but dont + # change begin/commit status + if len(objects): + self._commit_uow(*objects) + self.uow.flush(self, *objects) + return + if self.parent_uow is None: + self._commit_uow() + def _trans_commit(self, trans): + if trans.uow is self.uow and trans.isactive: + try: + self._commit_uow() + finally: + self.uow = self.parent_uow + self.parent_uow = None + def _trans_rollback(self, trans): + if trans.uow is self.uow: + self.uow = self.parent_uow + self.parent_uow = None + def _commit_uow(self, *obj): + self.was_pushed() + try: + self.uow.flush(self, *obj) + finally: + self.was_popped() + +Session = LegacySession + def get_id_key(ident, class_, entity_name=None): return Session.get_id_key(ident, class_, entity_name) @@ -228,19 +242,22 @@ def get_row_key(row, class_, primary_key, entity_name=None): return Session.get_row_key(row, class_, primary_key, entity_name) def begin(): - """begins a new UnitOfWork transaction. the next commit will affect only - objects that are created, modified, or deleted following the begin statement.""" + """deprecated. use s = Session(new_imap=False).""" return get_session().begin() def commit(*obj): - """commits the current UnitOfWork transaction. if a transaction was begun - via begin(), commits only those objects that were created, modified, or deleted - since that begin statement. otherwise commits all objects that have been + """deprecated; use flush(*obj)""" + get_session().flush(*obj) + +def flush(*obj): + """flushes the current UnitOfWork transaction. if a transaction was begun + via begin(), flushes only those objects that were created, modified, or deleted + since that begin statement. otherwise flushes all objects that have been changed. - + if individual objects are submitted, then only those objects are committed, and the begin/commit cycle is not affected.""" - get_session().commit(*obj) + get_session().flush(*obj) def clear(): """removes all current UnitOfWorks and IdentityMaps for this thread and diff --git a/lib/sqlalchemy/mapping/query.py b/lib/sqlalchemy/mapping/query.py index 09c2b9b6e..950c2be42 100644 --- a/lib/sqlalchemy/mapping/query.py +++ b/lib/sqlalchemy/mapping/query.py @@ -10,6 +10,7 @@ class Query(object): self.mapper = mapper self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh) self.order_by = kwargs.pop('order_by', self.mapper.order_by) + self.extension = kwargs.pop('extension', self.mapper.extension) self._session = kwargs.pop('session', None) if not hasattr(mapper, '_get_clause'): _get_clause = sql.and_() @@ -66,7 +67,7 @@ class Query(object): e.g. result = usermapper.select_by(user_name = 'fred') """ - ret = self.mapper.extension.select_by(self, *args, **params) + ret = self.extension.select_by(self, *args, **params) if ret is not mapper.EXT_PASS: return ret return self.select_whereclause(self._by_clause(*args, **params)) @@ -116,7 +117,7 @@ class Query(object): in this case, the developer must insure that an adequate set of columns exists in the rowset with which to build new object instances.""" - ret = self.mapper.extension.select(self, arg=arg, **kwargs) + ret = self.extension.select(self, arg=arg, **kwargs) if ret is not mapper.EXT_PASS: return ret elif arg is not None and isinstance(arg, sql.Selectable): diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 24392b3d9..acce555ab 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -23,8 +23,17 @@ import copy, re, string __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] +class SchemaMeta(type): + """provides universal constructor arguments for all SchemaItems""" + def __call__(self, *args, **kwargs): + engine = kwargs.pop('engine', None) + obj = type.__call__(self, *args, **kwargs) + obj._engine = engine + return obj + class SchemaItem(object): """base class for items that define a database schema.""" + __metaclass__ = SchemaMeta def _init_items(self, *args): for item in args: if item is not None: @@ -34,7 +43,20 @@ class SchemaItem(object): raise NotImplementedError() def __repr__(self): return "%s()" % self.__class__.__name__ - + +class EngineMixin(object): + """a mixin for SchemaItems that provides an "engine" accessor.""" + def _derived_engine(self): + """subclasses override this method to return an AbstractEngine + bound to a parent item""" + return None + def _get_engine(self): + if self._engine is not None: + return self._engine + else: + return self._derived_engine() + engine = property(_get_engine) + def _get_table_key(engine, name, schema): if schema is not None and schema == engine.get_default_schema_name(): schema = None @@ -43,14 +65,12 @@ def _get_table_key(engine, name, schema): else: return schema + "." + name -class TableSingleton(type): +class TableSingleton(SchemaMeta): """a metaclass used by the Table object to provide singleton behavior.""" def __call__(self, name, engine=None, *args, **kwargs): try: - if not isinstance(engine, SchemaEngine): + if engine is not None and not isinstance(engine, SchemaEngine): args = [engine] + list(args) - engine = None - if engine is None: engine = default_engine name = str(name) # in case of incoming unicode schema = kwargs.get('schema', None) @@ -58,6 +78,10 @@ class TableSingleton(type): redefine = kwargs.pop('redefine', False) mustexist = kwargs.pop('mustexist', False) useexisting = kwargs.pop('useexisting', False) + if not engine: + table = type.__call__(self, name, engine, **kwargs) + table._init_items(*args) + return table key = _get_table_key(engine, name, schema) table = engine.tables[key] if len(args): @@ -440,15 +464,14 @@ class ForeignKey(SchemaItem): self.parent.foreign_key = self self.parent.table.foreign_keys.append(self) -class DefaultGenerator(SchemaItem): +class DefaultGenerator(SchemaItem, EngineMixin): """Base class for column "default" values.""" - def __init__(self, for_update=False, engine=None): + def __init__(self, for_update=False): self.for_update = for_update - self.engine = engine + def _derived_engine(self): + return self.column.table.engine def _set_parent(self, column): self.column = column - if self.engine is None: - self.engine = column.table.engine if self.for_update: self.column.onupdate = self else: @@ -509,7 +532,7 @@ class Sequence(DefaultGenerator): return visitor.visit_sequence(self) -class Index(SchemaItem): +class Index(SchemaItem, EngineMixin): """Represents an index of columns from a database table """ def __init__(self, name, *columns, **kw): @@ -530,7 +553,8 @@ class Index(SchemaItem): self.unique = kw.pop('unique', False) self._init_items(*columns) - engine = property(lambda s:s.table.engine) + def _derived_engine(self): + return self.table.engine def _init_items(self, *args): for column in args: self.append_column(column) @@ -570,18 +594,21 @@ class Index(SchemaItem): for c in self.columns]), (self.unique and ', unique=True') or '') -class SchemaEngine(object): +class SchemaEngine(sql.AbstractEngine): """a factory object used to create implementations for schema objects. This object is the ultimate base class for the engine.SQLEngine class.""" def __init__(self): # a dictionary that stores Table objects keyed off their name (and possibly schema name) self.tables = {} - def reflecttable(self, table): """given a table, will query the database and populate its Column and ForeignKey objects.""" raise NotImplementedError() + def schemagenerator(self, **params): + raise NotImplementedError() + def schemadropper(self, **params): + raise NotImplementedError() class SchemaVisitor(sql.ClauseVisitor): """defines the visiting for SchemaItem objects""" diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index f6e2d03c9..2bc025e9f 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -246,6 +246,12 @@ def _is_literal(element): def is_column(col): return isinstance(col, ColumnElement) +class AbstractEngine(object): + def execute_compiled(self, compiled, parameters, echo=None, **kwargs): + raise NotImplementedError() + def compiler(self, statement, parameters, **kwargs): + raise NotImplementedError() + class ClauseParameters(util.OrderedDict): """represents a dictionary/iterator of bind parameter key names/values. Includes parameters compiled with a Compiled object as well as additional arguments passed to the Compiled object's get_params() method. Parameter values will be converted as per the TypeEngine objects present in the bind parameter objects. The non-converted value can be retrieved via the get_original method. For Compiled objects that compile positional parameters, the values() iteration of the object will return the parameter values in the correct order.""" def __init__(self, engine=None): @@ -340,8 +346,11 @@ class Compiled(ClauseVisitor): """executes this compiled object using the underlying SQLEngine""" if len(multiparams): params = multiparams - - return self.engine.execute_compiled(self, params) + + e = self.engine + if e is None: + raise InvalidRequestError("This Compiled object is not bound to any engine.") + return e.execute_compiled(self, params) def scalar(self, *multiparams, **params): """executes this compiled object via the execute() method, then @@ -356,7 +365,26 @@ class Compiled(ClauseVisitor): return row[0] else: return None - + +class Executor(object): + """handles the compilation/execution of a ClauseElement within the context of a particular AbtractEngine. This + AbstractEngine will usually be a SQLEngine or ConnectionProxy.""" + def __init__(self, clauseelement, abstractengine=None): + self.engine=abstractengine + self.clauseelement = clauseelement + def execute(self, *multiparams, **params): + return self.compile(*multiparams, **params).execute(*multiparams, **params) + def scalar(self, *multiparams, **params): + return self.compile(*multiparams, **params).scalar(*multiparams, **params) + def compile(self, *multiparams, **params): + if len(multiparams): + bindparams = multiparams[0] + else: + bindparams = params + compiler = self.engine.compiler(self.clauseelement, bindparams) + compiler.compile() + return compiler + class ClauseElement(object): """base class for elements of a programmatically constructed SQL expression.""" def _get_from_objects(self): @@ -415,10 +443,12 @@ class ClauseElement(object): engine = property(lambda s: s._find_engine(), doc="attempts to locate a SQLEngine within this ClauseElement structure, or returns None if none found.") - + def using(self, abstractengine): + return Executor(self, abstractengine) + def compile(self, engine = None, parameters = None, typemap=None, compiler=None): """compiles this SQL expression using its underlying SQLEngine to produce - a Compiled object. If no engine can be found, an ansisql engine is used. + a Compiled object. If no engine can be found, an ANSICompiler is used with no engine. bindparams is a dictionary representing the default bind parameters to be used with the statement. """ @@ -430,7 +460,7 @@ class ClauseElement(object): if compiler is None: import sqlalchemy.ansisql as ansisql - compiler = ansisql.ANSICompiler(self, parameters=parameters, typemap=typemap) + compiler = ansisql.ANSICompiler(self, parameters=parameters) compiler.compile() return compiler @@ -438,30 +468,10 @@ class ClauseElement(object): return str(self.compile()) def execute(self, *multiparams, **params): - """compiles and executes this SQL expression using its underlying SQLEngine. the - given **params are used as bind parameters when compiling and executing the - expression. the DBAPI cursor object is returned.""" - e = self.engine - if len(multiparams): - bindparams = multiparams[0] - else: - bindparams = params - c = self.compile(e, parameters=bindparams) - return c.execute(*multiparams, **params) + return self.using(self.engine).execute(*multiparams, **params) def scalar(self, *multiparams, **params): - """executes this SQL expression via the execute() method, then - returns the first column of the first row. Useful for executing functions, - sequences, rowcounts, etc.""" - # we are still going off the assumption that fetching only the first row - # in a result set is not performance-wise any different than specifying limit=1 - # else we'd have to construct a copy of the select() object with the limit - # installed (else if we change the existing select, not threadsafe) - row = self.execute(*multiparams, **params).fetchone() - if row is not None: - return row[0] - else: - return None + return self.using(self.engine).scalar(*multiparams, **params) def __and__(self, other): return and_(self, other) |
