summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/strategies.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/strategies.py')
-rw-r--r--lib/sqlalchemy/orm/strategies.py500
1 files changed, 500 insertions, 0 deletions
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
new file mode 100644
index 000000000..e51dd5abd
--- /dev/null
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -0,0 +1,500 @@
+# strategies.py
+# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from sqlalchemy import sql, schema, util, attributes, exceptions, sql_util, logging
+import mapper
+from interfaces import *
+import session as sessionlib
+import util as mapperutil
+import sets, random
+
+
+class ColumnLoader(LoaderStrategy):
+ def init(self):
+ super(ColumnLoader, self).init()
+ self.columns = self.parent_property.columns
+ def setup_query(self, context, eagertable=None, **kwargs):
+ for c in self.columns:
+ if eagertable is not None:
+ context.statement.append_column(eagertable.corresponding_column(c))
+ else:
+ context.statement.append_column(c)
+
+ def init_class_attribute(self):
+ self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
+ sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=lambda x: self.columns[0].type.copy_value(x), compare_function=lambda x,y:self.columns[0].type.compare_values(x,y), mutable_scalars=self.columns[0].type.is_mutable())
+
+ def process_row(self, selectcontext, instance, row, identitykey, isnew):
+ if isnew:
+ self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
+ instance.__dict__[self.key] = row[self.columns[0]]
+
+ColumnLoader.logger = logging.class_logger(ColumnLoader)
+
+class DeferredColumnLoader(LoaderStrategy):
+ """describes an object attribute that corresponds to a table column, which also
+ will "lazy load" its value from the table. this is per-column lazy loading."""
+ def init(self):
+ super(DeferredColumnLoader, self).init()
+ self.columns = self.parent_property.columns
+ self.group = self.parent_property.group
+
+ def init_class_attribute(self):
+ self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
+ sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=lambda i:self.setup_loader(i), copy_function=lambda x: self.columns[0].type.copy_value(x), compare_function=lambda x,y:self.columns[0].type.compare_values(x,y), mutable_scalars=self.columns[0].type.is_mutable())
+
+ def setup_query(self, context, **kwargs):
+ pass
+
+ def process_row(self, selectcontext, instance, row, identitykey, isnew):
+ if isnew:
+ if not self.is_default or len(selectcontext.options):
+ sessionlib.attribute_manager.init_instance_attribute(instance, self.key, False, callable_=self.setup_loader(instance, selectcontext.options))
+ else:
+ sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
+
+ def setup_loader(self, instance, options=None):
+ if not mapper.has_mapper(instance):
+ return None
+ else:
+ prop = mapper.object_mapper(instance).props[self.key]
+ if prop is not self.parent_property:
+ return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
+ def lazyload():
+ self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), str(self.group)))
+ try:
+ pk = self.parent.pks_by_table[self.columns[0].table]
+ except KeyError:
+ pk = self.columns[0].table.primary_key
+
+ clause = sql.and_()
+ for primary_key in pk:
+ attr = self.parent._getattrbycolumn(instance, primary_key)
+ if not attr:
+ return None
+ clause.clauses.append(primary_key == attr)
+
+ session = sessionlib.object_session(instance)
+ if session is None:
+ raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+
+ localparent = mapper.object_mapper(instance)
+ if self.group is not None:
+ groupcols = [p for p in localparent.props.values() if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
+ result = session.execute(localparent, sql.select([g.columns[0] for g in groupcols], clause, use_labels=True), None)
+ try:
+ row = result.fetchone()
+ for prop in groupcols:
+ if prop is self:
+ continue
+ # set a scalar object instance directly on the object,
+ # bypassing SmartProperty event handlers.
+ sessionlib.attribute_manager.init_instance_attribute(instance, prop.key, uselist=False)
+ instance.__dict__[prop.key] = row[prop.columns[0]]
+ return row[self.columns[0]]
+ finally:
+ result.close()
+ else:
+ return session.scalar(localparent, sql.select([self.columns[0]], clause, use_labels=True),None)
+
+ return lazyload
+
+DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader)
+
+class DeferredOption(StrategizedOption):
+ def __init__(self, key, defer=False):
+ super(DeferredOption, self).__init__(key)
+ self.defer = defer
+ def get_strategy_class(self):
+ if self.defer:
+ return DeferredColumnLoader
+ else:
+ return ColumnLoader
+
+class AbstractRelationLoader(LoaderStrategy):
+ def init(self):
+ super(AbstractRelationLoader, self).init()
+ self.primaryjoin = self.parent_property.primaryjoin
+ self.secondaryjoin = self.parent_property.secondaryjoin
+ self.secondary = self.parent_property.secondary
+ self.foreignkey = self.parent_property.foreignkey
+ self.mapper = self.parent_property.mapper
+ self.target = self.parent_property.target
+ self.uselist = self.parent_property.uselist
+ self.cascade = self.parent_property.cascade
+ self.attributeext = self.parent_property.attributeext
+ self.order_by = self.parent_property.order_by
+
+ def _init_instance_attribute(self, instance, callable_=None):
+ return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True, callable_=callable_)
+
+ def _register_attribute(self, class_, callable_=None):
+ self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__))
+ sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, callable_=callable_)
+
+class NoLoader(AbstractRelationLoader):
+ def process_row(self, selectcontext, instance, row, identitykey, isnew):
+ if isnew:
+ if not self.is_default or len(selectcontext.options):
+ self.logger.debug("set instance-level no loader on %s" % mapperutil.attribute_str(instance, self.key))
+ self._init_instance_attribute(instance)
+
+NoLoader.logger = logging.class_logger(NoLoader)
+
+class LazyLoader(AbstractRelationLoader):
+ def init(self):
+ super(LazyLoader, self).init()
+ (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self.parent.unjoined_table, self.primaryjoin, self.secondaryjoin, self.foreignkey)
+ # determine if our "lazywhere" clause is the same as the mapper's
+ # get() clause. then we can just use mapper.get()
+ self.use_get = not self.uselist and self.mapper.query()._get_clause.compare(self.lazywhere)
+
+ def init_class_attribute(self):
+ self._register_attribute(self.parent.class_, callable_=lambda i: self.setup_loader(i))
+
+ def setup_loader(self, instance, options=None):
+ if not mapper.has_mapper(instance):
+ return None
+ else:
+ prop = mapper.object_mapper(instance).props[self.key]
+ if prop is not self.parent_property:
+ return prop._get_strategy(LazyLoader).setup_loader(instance)
+ def lazyload():
+ self.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
+ params = {}
+ allparams = True
+ # if the instance wasnt loaded from the database, then it cannot lazy load
+ # child items. one reason for this is that a bi-directional relationship
+ # will not update properly, since bi-directional uses lazy loading functions
+ # in both directions, and this instance will not be present in the lazily-loaded
+ # results of the other objects since its not in the database
+ if not mapper.has_identity(instance):
+ return None
+ #print "setting up loader, lazywhere", str(self.lazywhere), "binds", self.lazybinds
+ for col, bind in self.lazybinds.iteritems():
+ params[bind.key] = self.parent._getattrbycolumn(instance, col)
+ if params[bind.key] is None:
+ allparams = False
+ break
+
+ if not allparams:
+ return None
+
+ session = sessionlib.object_session(instance)
+ if session is None:
+ try:
+ session = mapper.object_mapper(instance).get_session()
+ except exceptions.InvalidRequestError:
+ raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+
+ # if we have a simple straight-primary key load, use mapper.get()
+ # to possibly save a DB round trip
+ if self.use_get:
+ ident = []
+ for primary_key in self.mapper.pks_by_table[self.mapper.mapped_table]:
+ bind = self.lazyreverse[primary_key]
+ ident.append(params[bind.key])
+ return self.mapper.using(session).get(ident)
+ elif self.order_by is not False:
+ order_by = self.order_by
+ elif self.secondary is not None and self.secondary.default_order_by() is not None:
+ order_by = self.secondary.default_order_by()
+ else:
+ order_by = False
+ result = session.query(self.mapper, with_options=options).select_whereclause(self.lazywhere, order_by=order_by, params=params)
+
+ if self.uselist:
+ return result
+ else:
+ if len(result):
+ return result[0]
+ else:
+ return None
+ return lazyload
+
+ def process_row(self, selectcontext, instance, row, identitykey, isnew):
+ if isnew:
+ # new object instance being loaded from a result row
+ if not self.is_default or len(selectcontext.options):
+ self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
+ # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
+ # which will override the clareset_instance_attributess-level behavior
+ self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options))
+ else:
+ self.logger.debug("set class-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
+ # we are the primary manager for this attribute on this class - reset its per-instance attribute state,
+ # so that the class-level lazy loader is executed when next referenced on this instance.
+ # this usually is not needed unless the constructor of the object referenced the attribute before we got
+ # to load data into it.
+ sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
+
+ def _create_lazy_clause(self, table, primaryjoin, secondaryjoin, foreignkey):
+ binds = {}
+ reverse = {}
+ def column_in_table(table, column):
+ return table.corresponding_column(column, raiseerr=False, keys_ok=False) is not None
+
+ def bind_label():
+ return "lazy_" + hex(random.randint(0, 65535))[2:]
+ def visit_binary(binary):
+ circular = isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column) and binary.left.table is binary.right.table
+ if isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column) and ((not circular and column_in_table(table, binary.left)) or (circular and binary.right in foreignkey)):
+ col = binary.left
+ binary.left = binds.setdefault(binary.left,
+ sql.BindParamClause(bind_label(), None, shortname=binary.left.name, type=binary.right.type))
+ reverse[binary.right] = binds[col]
+
+ if isinstance(binary.right, schema.Column) and isinstance(binary.left, schema.Column) and ((not circular and column_in_table(table, binary.right)) or (circular and binary.left in foreignkey)):
+ col = binary.right
+ binary.right = binds.setdefault(binary.right,
+ sql.BindParamClause(bind_label(), None, shortname=binary.right.name, type=binary.left.type))
+ reverse[binary.left] = binds[col]
+
+ lazywhere = primaryjoin.copy_container()
+ li = mapperutil.BinaryVisitor(visit_binary)
+ lazywhere.accept_visitor(li)
+ if secondaryjoin is not None:
+ lazywhere = sql.and_(lazywhere, secondaryjoin)
+ LazyLoader.logger.debug("create_lazy_clause " + str(lazywhere))
+ return (lazywhere, binds, reverse)
+
+LazyLoader.logger = logging.class_logger(LazyLoader)
+
+
+
+class EagerLoader(AbstractRelationLoader):
+ """loads related objects inline with a parent query."""
+ def init(self):
+ super(EagerLoader, self).init()
+ if self.parent.isa(self.mapper):
+ raise exceptions.ArgumentError("Error creating eager relationship '%s' on parent class '%s' to child class '%s': Cant use eager loading on a self referential relationship." % (self.key, repr(self.parent.class_), repr(self.mapper.class_)))
+ self.parent._has_eager = True
+
+ self.clauses = {}
+ self.clauses_by_lead_mapper = {}
+
+ class AliasedClauses(object):
+ """defines a set of join conditions and table aliases which are aliased on a randomly-generated
+ alias name, corresponding to the connection of an optional parent AliasedClauses object and a
+ target mapper.
+
+ EagerLoader has a distinct AliasedClauses object per parent AliasedClauses object,
+ so that all paths from one mapper to another across a chain of eagerloaders generates a distinct
+ chain of joins. The AliasedClauses objects are generated and cached on an as-needed basis.
+
+ e.g.:
+
+ mapper A -->
+ (EagerLoader 'items') -->
+ mapper B -->
+ (EagerLoader 'keywords') -->
+ mapper C
+
+ will generate:
+
+ EagerLoader 'items' --> {
+ None : AliasedClauses(items, None, alias_suffix='AB34') # mappera JOIN mapperb_AB34
+ }
+
+ EagerLoader 'keywords' --> [
+ None : AliasedClauses(keywords, None, alias_suffix='43EF') # mapperb JOIN mapperc_43EF
+ AliasedClauses(items, None, alias_suffix='AB34') :
+ AliasedClauses(keywords, items, alias_suffix='8F44') # mapperb_AB34 JOIN mapperc_8F44
+ ]
+ """
+ def __init__(self, eagerloader, parentclauses=None):
+ self.parent = eagerloader
+ self.target = eagerloader.target
+ self.eagertarget = eagerloader.target.alias()
+ if eagerloader.secondary:
+ self.eagersecondary = eagerloader.secondary.alias()
+ self.aliasizer = sql_util.Aliasizer(eagerloader.target, eagerloader.secondary, aliases={
+ eagerloader.target:self.eagertarget,
+ eagerloader.secondary:self.eagersecondary
+ })
+ self.eagersecondaryjoin = eagerloader.secondaryjoin.copy_container()
+ self.eagersecondaryjoin.accept_visitor(self.aliasizer)
+ self.eagerprimary = eagerloader.primaryjoin.copy_container()
+ self.eagerprimary.accept_visitor(self.aliasizer)
+ else:
+ self.aliasizer = sql_util.Aliasizer(eagerloader.target, aliases={eagerloader.target:self.eagertarget})
+ self.eagerprimary = eagerloader.primaryjoin.copy_container()
+ self.eagerprimary.accept_visitor(self.aliasizer)
+
+ if parentclauses is not None:
+ self.eagerprimary.accept_visitor(parentclauses.aliasizer)
+
+ if eagerloader.order_by:
+ self.eager_order_by = self._aliasize_orderby(eagerloader.order_by)
+ else:
+ self.eager_order_by = None
+
+ self._row_decorator = self._create_decorator_row()
+
+ def _aliasize_orderby(self, orderby, copy=True):
+ if copy:
+ orderby = [o.copy_container() for o in util.to_list(orderby)]
+ else:
+ orderby = util.to_list(orderby)
+ for i in range(0, len(orderby)):
+ if isinstance(orderby[i], schema.Column):
+ orderby[i] = self.eagertarget.corresponding_column(orderby[i])
+ else:
+ orderby[i].accept_visitor(self.aliasizer)
+ return orderby
+
+ def _create_decorator_row(self):
+ class EagerRowAdapter(object):
+ def __init__(self, row):
+ self.row = row
+ def has_key(self, key):
+ return map.has_key(key) or self.row.has_key(key)
+ def __getitem__(self, key):
+ if map.has_key(key):
+ key = map[key]
+ return self.row[key]
+ def keys(self):
+ return map.keys()
+ map = {}
+ for c in self.eagertarget.c:
+ parent = self.target.corresponding_column(c)
+ map[parent] = c
+ map[parent._label] = c
+ map[parent.name] = c
+ return EagerRowAdapter
+
+ def _decorate_row(self, row):
+ # adapts a row at row iteration time to transparently
+ # convert plain columns into the aliased columns that were actually
+ # added to the column clause of the SELECT.
+ return self._row_decorator(row)
+
+ def init_class_attribute(self):
+ self.parent_property._get_strategy(LazyLoader).init_class_attribute()
+
+ def setup_query(self, context, eagertable=None, parentclauses=None, parentmapper=None, **kwargs):
+ """add a left outer join to the statement thats being constructed"""
+ if parentmapper is None:
+ localparent = context.mapper
+ else:
+ localparent = parentmapper
+
+ if self in context.recursion_stack:
+ return
+ else:
+ context.recursion_stack.add(self)
+
+ statement = context.statement
+
+ if hasattr(statement, '_outerjoin'):
+ towrap = statement._outerjoin
+ elif isinstance(localparent.mapped_table, schema.Table):
+ # if the mapper is against a plain Table, look in the from_obj of the select statement
+ # to join against whats already there.
+ for (fromclause, finder) in [(x, sql_util.TableFinder(x)) for x in statement.froms]:
+ # dont join against an Alias'ed Select. we are really looking either for the
+ # table itself or a Join that contains the table. this logic still might need
+ # adjustments for scenarios not thought of yet.
+ if not isinstance(fromclause, sql.Alias) and localparent.mapped_table in finder:
+ towrap = fromclause
+ break
+ else:
+ raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), self.localparent.mapped_table))
+ else:
+ # if the mapper is against a select statement or something, we cant handle that at the
+ # same time as a custom FROM clause right now.
+ towrap = localparent.mapped_table
+
+ try:
+ clauses = self.clauses[parentclauses]
+ except KeyError:
+ clauses = EagerLoader.AliasedClauses(self, parentclauses)
+ self.clauses[parentclauses] = clauses
+ self.clauses_by_lead_mapper[context.mapper] = clauses
+
+ if self.secondaryjoin is not None:
+ statement._outerjoin = sql.outerjoin(towrap, clauses.eagersecondary, clauses.eagerprimary).outerjoin(clauses.eagertarget, clauses.eagersecondaryjoin)
+ if self.order_by is False and self.secondary.default_order_by() is not None:
+ statement.order_by(*clauses.eagersecondary.default_order_by())
+ else:
+ statement._outerjoin = towrap.outerjoin(clauses.eagertarget, clauses.eagerprimary)
+ if self.order_by is False and clauses.eagertarget.default_order_by() is not None:
+ statement.order_by(*clauses.eagertarget.default_order_by())
+
+ if clauses.eager_order_by:
+ statement.order_by(*util.to_list(clauses.eager_order_by))
+ elif getattr(statement, 'order_by_clause', None):
+ clauses._aliasize_orderby(statement.order_by_clause, False)
+
+ statement.append_from(statement._outerjoin)
+ for value in self.mapper.props.values():
+ value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.mapper)
+
+ def process_row(self, selectcontext, instance, row, identitykey, isnew):
+ """receive a row. tell our mapper to look for a new object instance in the row, and attach
+ it to a list on the parent instance."""
+
+ if self in selectcontext.recursion_stack:
+ return
+
+ try:
+ clauses = self.clauses_by_lead_mapper[selectcontext.mapper]
+ decorated_row = clauses._decorate_row(row)
+ # check for identity key
+ identity_key = self.mapper._row_identity_key(decorated_row)
+ except KeyError:
+ # else degrade to a lazy loader
+ self.logger.debug("degrade to lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
+ self.parent_property._get_strategy(LazyLoader).process_row(selectcontext, instance, row, identitykey, isnew)
+ return
+
+ if not self.uselist:
+ self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key))
+ if isnew:
+ # set a scalar object instance directly on the parent object,
+ # bypassing SmartProperty event handlers.
+ instance.__dict__[self.key] = self.mapper._instance(selectcontext, decorated_row, None)
+ else:
+ # call _instance on the row, even though the object has been created,
+ # so that we further descend into properties
+ self.mapper._instance(selectcontext, decorated_row, None)
+ else:
+ if isnew:
+ self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key))
+ # call the SmartProperty's initialize() method to create a new, blank list
+ l = getattr(instance.__class__, self.key).initialize(instance)
+
+ # create an appender object which will add set-like semantics to the list
+ appender = util.UniqueAppender(l.data)
+
+ # store it in the "scratch" area, which is local to this load operation.
+ selectcontext.attributes[(instance, self.key)] = appender
+ result_list = selectcontext.attributes[(instance, self.key)]
+ self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key))
+ # TODO: recursion check a speed hit...? try to get a "termination point" into the AliasedClauses
+ # or EagerRowAdapter ?
+ selectcontext.recursion_stack.add(self)
+ try:
+ self.mapper._instance(selectcontext, decorated_row, result_list)
+ finally:
+ selectcontext.recursion_stack.remove(self)
+
+EagerLoader.logger = logging.class_logger(EagerLoader)
+
+class EagerLazyOption(StrategizedOption):
+ def __init__(self, key, lazy=True):
+ super(EagerLazyOption, self).__init__(key)
+ self.lazy = lazy
+ def get_strategy_class(self):
+ if self.lazy:
+ return LazyLoader
+ elif self.lazy is False:
+ return EagerLoader
+ elif self.lazy is None:
+ return NoLoader
+
+
+