diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/ext/baked.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/loading.py | 53 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 111 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategy_options.py | 108 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/util.py | 21 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 48 |
9 files changed, 340 insertions, 37 deletions
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index ba3c2aed0..c0fe963ac 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -154,7 +154,7 @@ class BakedQuery(object): self._spoiled = True return self - def _add_lazyload_options(self, options, effective_path): + def _add_lazyload_options(self, options, effective_path, cache_path=None): """Used by per-state lazy loaders to add options to the "lazy load" query from a parent query. @@ -166,13 +166,16 @@ class BakedQuery(object): key = () - if effective_path.path[0].is_aliased_class: + if not cache_path: + cache_path = effective_path + + if cache_path.path[0].is_aliased_class: # paths that are against an AliasedClass are unsafe to cache # with since the AliasedClass is an ad-hoc object. self.spoil() else: for opt in options: - cache_key = opt._generate_cache_key(effective_path) + cache_key = opt._generate_cache_key(cache_path) if cache_key is False: self.spoil() elif cache_key is not None: @@ -181,7 +184,7 @@ class BakedQuery(object): self.add_criteria( lambda q: q._with_current_path(effective_path). _conditional_options(*options), - effective_path.path, key + cache_path.path, key ) def _retrieve_baked_query(self, session): diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index adfe2360a..7ecd5b67e 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -246,6 +246,7 @@ immediateload = strategy_options.immediateload._unbound_fn noload = strategy_options.noload._unbound_fn raiseload = strategy_options.raiseload._unbound_fn defaultload = strategy_options.defaultload._unbound_fn +selectin_polymorphic = strategy_options.selectin_polymorphic._unbound_fn from .strategy_options import Load @@ -268,6 +269,7 @@ def __go(lcls): from .. import util as sa_util from . import dynamic from . import events + from . import loading import inspect as _inspect __all__ = sorted(name for name, obj in lcls.items() diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 7feec660d..48c0db851 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -19,6 +19,7 @@ from . import attributes, exc as orm_exc from ..sql import util as sql_util from . import strategy_options from . import path_registry +from .. import sql from .util import _none_set, state_str from .base import _SET_DEFERRED_EXPIRED, _DEFER_FOR_STATE @@ -353,6 +354,27 @@ def _instance_processor( session_id = context.session.hash_key version_check = context.version_check runid = context.runid + + if not refresh_state and _polymorphic_from is not None: + key = ('loader', path.path) + if ( + key in context.attributes and + context.attributes[key].strategy == + (('selectinload_polymorphic', True), ) and + mapper in context.attributes[key].local_opts['mappers'] + ) or mapper.polymorphic_load == 'selectin': + + # only_load_props goes w/ refresh_state only, and in a refresh + # we are a single row query for the exact entity; polymorphic + # loading does not apply + assert only_load_props is None + + callable_ = _load_subclass_via_in(context, path, mapper) + + PostLoad.callable_for_path( + context, load_path, mapper, + callable_, mapper) + post_load = PostLoad.for_context(context, load_path, only_load_props) if refresh_state: @@ -501,6 +523,37 @@ def _instance_processor( return _instance +@util.dependencies("sqlalchemy.ext.baked") +def _load_subclass_via_in(baked, context, path, mapper): + + zero_idx = len(mapper.base_mapper.primary_key) == 1 + + q, enable_opt, disable_opt = mapper._subclass_load_via_in + + def do_load(context, path, states, load_only, effective_entity): + orig_query = context.query + + q._add_lazyload_options( + (enable_opt, ) + orig_query._with_options + (disable_opt, ), + path.parent, cache_path=path + ) + + if orig_query._populate_existing: + q.add_criteria( + lambda q: q.populate_existing() + ) + + q(context.session).params( + primary_keys=[ + state.key[1][0] if zero_idx else state.key[1] + for state, load_attrs in states + if state.mapper.isa(mapper) + ] + ).all() + + return do_load + + def _populate_full( context, row, state, dict_, isnew, load_path, loaded_instance, populate_existing, populators): diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 6bf86d0ef..1042442c0 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -106,6 +106,7 @@ class Mapper(InspectionAttr): polymorphic_identity=None, concrete=False, with_polymorphic=None, + polymorphic_load=None, allow_partial_pks=True, batch=True, column_prefix=None, @@ -381,6 +382,27 @@ class Mapper(InspectionAttr): :paramref:`.mapper.passive_deletes` - supporting ON DELETE CASCADE for joined-table inheritance mappers + :param polymorphic_load: Specifies "polymorphic loading" behavior + for a subclass in an inheritance hierarchy (joined and single + table inheritance only). Valid values are: + + * "'inline'" - specifies this class should be part of the + "with_polymorphic" mappers, e.g. its columns will be included + in a SELECT query against the base. + + * "'selectin'" - specifies that when instances of this class + are loaded, an additional SELECT will be emitted to retrieve + the columns specific to this subclass. The SELECT uses + IN to fetch multiple subclasses at once. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`with_polymorphic_mapper_config` + + :ref:`polymorphic_selectin` + :param polymorphic_on: Specifies the column, attribute, or SQL expression used to determine the target class for an incoming row, when inheriting classes are present. @@ -622,8 +644,6 @@ class Mapper(InspectionAttr): else: self.confirm_deleted_rows = confirm_deleted_rows - self._set_with_polymorphic(with_polymorphic) - if isinstance(self.local_table, expression.SelectBase): raise sa_exc.InvalidRequestError( "When mapping against a select() construct, map against " @@ -632,11 +652,8 @@ class Mapper(InspectionAttr): "SELECT from a subquery that does not have an alias." ) - if self.with_polymorphic and \ - isinstance(self.with_polymorphic[1], - expression.SelectBase): - self.with_polymorphic = (self.with_polymorphic[0], - self.with_polymorphic[1].alias()) + self._set_with_polymorphic(with_polymorphic) + self.polymorphic_load = polymorphic_load # our 'polymorphic identity', a string name that when located in a # result set row indicates this Mapper should be used to construct @@ -1037,6 +1054,19 @@ class Mapper(InspectionAttr): ) self.polymorphic_map[self.polymorphic_identity] = self + if self.polymorphic_load and self.concrete: + raise exc.ArgumentError( + "polymorphic_load is not currently supported " + "with concrete table inheritance") + if self.polymorphic_load == 'inline': + self.inherits._add_with_polymorphic_subclass(self) + elif self.polymorphic_load == 'selectin': + pass + elif self.polymorphic_load is not None: + raise sa_exc.ArgumentError( + "unknown argument for polymorphic_load: %r" % + self.polymorphic_load) + else: self._all_tables = set() self.base_mapper = self @@ -1077,9 +1107,22 @@ class Mapper(InspectionAttr): expression.SelectBase): self.with_polymorphic = (self.with_polymorphic[0], self.with_polymorphic[1].alias()) + if self.configured: self._expire_memoizations() + def _add_with_polymorphic_subclass(self, mapper): + subcl = mapper.class_ + if self.with_polymorphic is None: + self._set_with_polymorphic((subcl,)) + elif self.with_polymorphic[0] != '*': + self._set_with_polymorphic( + ( + self.with_polymorphic[0] + (subcl, ), + self.with_polymorphic[1] + ) + ) + def _set_concrete_base(self, mapper): """Set the given :class:`.Mapper` as the 'inherits' for this :class:`.Mapper`, assuming this :class:`.Mapper` is concrete @@ -2663,6 +2706,60 @@ class Mapper(InspectionAttr): cols.extend(props[key].columns) return sql.select(cols, cond, use_labels=True) + @_memoized_configured_property + @util.dependencies( + "sqlalchemy.ext.baked", + "sqlalchemy.orm.strategy_options") + def _subclass_load_via_in(self, baked, strategy_options): + """Assemble a BakedQuery that can load the columns local to + this subclass as a SELECT with IN. + + """ + assert self.inherits + + polymorphic_prop = self._columntoproperty[ + self.polymorphic_on] + keep_props = set( + [polymorphic_prop] + self._identity_key_props) + + disable_opt = strategy_options.Load(self) + enable_opt = strategy_options.Load(self) + + for prop in self.attrs: + if prop.parent is self or prop in keep_props: + # "enable" options, to turn on the properties that we want to + # load by default (subject to options from the query) + enable_opt.set_generic_strategy( + (prop.key, ), + dict(prop.strategy_key) + ) + else: + # "disable" options, to turn off the properties from the + # superclass that we *don't* want to load, applied after + # the options from the query to override them + disable_opt.set_generic_strategy( + (prop.key, ), + {"do_nothing": True} + ) + + if len(self.primary_key) > 1: + in_expr = sql.tuple_(*self.primary_key) + else: + in_expr = self.primary_key[0] + + q = baked.BakedQuery( + self._compiled_cache, + lambda session: session.query(self), + (self, ) + ) + q += lambda q: q.filter( + in_expr.in_( + sql.bindparam('primary_keys', expanding=True) + ) + ).order_by(*self.primary_key) + + return q, enable_opt, disable_opt + def cascade_iterator(self, type_, state, halt_on=None): """Iterate each element and its mapper in an object graph, for all relationships that meet the given cascade rule. diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index dc69ae99d..e48462d35 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -196,6 +196,7 @@ class ColumnLoader(LoaderStrategy): @log.class_logger @properties.ColumnProperty.strategy_for(deferred=True, instrument=True) +@properties.ColumnProperty.strategy_for(do_nothing=True) class DeferredColumnLoader(LoaderStrategy): """Provide loading behavior for a deferred :class:`.ColumnProperty`.""" @@ -336,6 +337,18 @@ class AbstractRelationshipLoader(LoaderStrategy): @log.class_logger +@properties.RelationshipProperty.strategy_for(do_nothing=True) +class DoNothingLoader(LoaderStrategy): + """Relationship loader that makes no change to the object's state. + + Compared to NoLoader, this loader does not initialize the + collection/attribute to empty/none; the usual default LazyLoader will + take effect. + + """ + + +@log.class_logger @properties.RelationshipProperty.strategy_for(lazy="noload") @properties.RelationshipProperty.strategy_for(lazy=None) class NoLoader(AbstractRelationshipLoader): @@ -711,6 +724,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): self, context, path, loadopt, mapper, result, adapter, populators): key = self.key + if not self.is_class_level: # we are not the primary manager for this attribute # on this class - set up a @@ -1804,6 +1818,9 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): selectin_path = ( context.query._current_path or orm_util.PathRegistry.root) + path + if not orm_util._entity_isa(path[-1], self.parent): + return + if loading.PostLoad.path_exists(context, selectin_path, self.key): return @@ -1914,6 +1931,7 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): } for key, state, overwrite in chunk: + if not overwrite and self.key in state.dict: continue diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index df13f05db..d3f456969 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -13,7 +13,7 @@ from .attributes import QueryableAttribute from .. import util from ..sql.base import _generative, Generative from .. import exc as sa_exc, inspect -from .base import _is_aliased_class, _class_to_mapper +from .base import _is_aliased_class, _class_to_mapper, _is_mapped_class from . import util as orm_util from .path_registry import PathRegistry, TokenRegistry, \ _WILDCARD_TOKEN, _DEFAULT_TOKEN @@ -63,6 +63,7 @@ class Load(Generative, MapperOption): self.context = util.OrderedDict() self.local_opts = {} self._of_type = None + self.is_class_strategy = False @classmethod def for_existing_path(cls, path): @@ -127,6 +128,7 @@ class Load(Generative, MapperOption): return cloned is_opts_only = False + is_class_strategy = False strategy = None propagate_to_loaders = False @@ -148,6 +150,7 @@ class Load(Generative, MapperOption): def _generate_path(self, path, attr, wildcard_key, raiseerr=True): self._of_type = None + if raiseerr and not path.has_entity: if isinstance(path, TokenRegistry): raise sa_exc.ArgumentError( @@ -187,6 +190,14 @@ class Load(Generative, MapperOption): attr = attr.property path = path[attr] + elif _is_mapped_class(attr): + if not attr.common_parent(path.mapper): + if raiseerr: + raise sa_exc.ArgumentError( + "Attribute '%s' does not " + "link from element '%s'" % (attr, path.entity)) + else: + return None else: prop = attr.property @@ -246,6 +257,7 @@ class Load(Generative, MapperOption): self, attr, strategy, propagate_to_loaders=True): strategy = self._coerce_strat(strategy) + self.is_class_strategy = False self.propagate_to_loaders = propagate_to_loaders # if the path is a wildcard, this will set propagate_to_loaders=False self._generate_path(self.path, attr, "relationship") @@ -257,6 +269,7 @@ class Load(Generative, MapperOption): def set_column_strategy(self, attrs, strategy, opts=None, opts_only=False): strategy = self._coerce_strat(strategy) + self.is_class_strategy = False for attr in attrs: cloned = self._generate() cloned.strategy = strategy @@ -267,6 +280,31 @@ class Load(Generative, MapperOption): if opts_only: cloned.is_opts_only = True cloned._set_path_strategy() + self.is_class_strategy = False + + @_generative + def set_generic_strategy(self, attrs, strategy): + strategy = self._coerce_strat(strategy) + + for attr in attrs: + path = self._generate_path(self.path, attr, None) + cloned = self._generate() + cloned.strategy = strategy + cloned.path = path + cloned.propagate_to_loaders = True + cloned._set_path_strategy() + + @_generative + def set_class_strategy(self, strategy, opts): + strategy = self._coerce_strat(strategy) + cloned = self._generate() + cloned.is_class_strategy = True + path = cloned._generate_path(self.path, None, None) + cloned.strategy = strategy + cloned.path = path + cloned.propagate_to_loaders = True + cloned._set_path_strategy() + cloned.local_opts.update(opts) def _set_for_path(self, context, path, replace=True, merge_opts=False): if merge_opts or not replace: @@ -284,7 +322,7 @@ class Load(Generative, MapperOption): self.local_opts.update(existing.local_opts) def _set_path_strategy(self): - if self.path.has_entity: + if not self.is_class_strategy and self.path.has_entity: effective_path = self.path.parent else: effective_path = self.path @@ -367,7 +405,10 @@ class _UnboundLoad(Load): if attr == _DEFAULT_TOKEN: self.propagate_to_loaders = False attr = "%s:%s" % (wildcard_key, attr) - path = path + (attr, ) + if path and _is_mapped_class(path[-1]) and not self.is_class_strategy: + path = path[0:-1] + if attr: + path = path + (attr, ) self.path = path return path @@ -502,7 +543,12 @@ class _UnboundLoad(Load): (User, User.orders.property, Order, Order.items.property)) """ + start_path = self.path + + if self.is_class_strategy and current_path: + start_path += (entities[0], ) + # _current_path implies we're in a # secondary load with an existing path @@ -517,7 +563,8 @@ class _UnboundLoad(Load): token = start_path[0] if isinstance(token, util.string_types): - entity = self._find_entity_basestring(entities, token, raiseerr) + entity = self._find_entity_basestring( + entities, token, raiseerr) elif isinstance(token, PropComparator): prop = token.property entity = self._find_entity_prop_comparator( @@ -525,7 +572,10 @@ class _UnboundLoad(Load): prop.key, token._parententity, raiseerr) - + elif self.is_class_strategy and _is_mapped_class(token): + entity = inspect(token) + if entity not in entities: + entity = None else: raise sa_exc.ArgumentError( "mapper option expects " @@ -541,7 +591,6 @@ class _UnboundLoad(Load): # we just located, then go through the rest of our path # tokens and populate into the Load(). loader = Load(path_element) - if context is not None: loader.context = context else: @@ -549,16 +598,19 @@ class _UnboundLoad(Load): loader.strategy = self.strategy loader.is_opts_only = self.is_opts_only + loader.is_class_strategy = self.is_class_strategy path = loader.path - for token in start_path: - if not loader._generate_path( - loader.path, token, None, raiseerr): - return + + if not loader.is_class_strategy: + for token in start_path: + if not loader._generate_path( + loader.path, token, None, raiseerr): + return loader.local_opts.update(self.local_opts) - if loader.path.has_entity: + if not loader.is_class_strategy and loader.path.has_entity: effective_path = loader.path.parent else: effective_path = loader.path @@ -1289,3 +1341,37 @@ def undefer_group(loadopt, name): @undefer_group._add_unbound_fn def undefer_group(name): return _UnboundLoad().undefer_group(name) + + +@loader_option() +def selectin_polymorphic(loadopt, classes): + """Indicate an eager load should take place for all attributes + specific to a subclass. + + This uses an additional SELECT with IN against all matched primary + key values, and is the per-query analogue to the ``"selectin"`` + setting on the :paramref:`.mapper.polymorphic_load` parameter. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`inheritance_polymorphic_load` + + """ + loadopt.set_class_strategy( + {"selectinload_polymorphic": True}, + opts={"mappers": tuple(sorted((inspect(cls) for cls in classes), key=id))} + ) + return loadopt + + +@selectin_polymorphic._add_unbound_fn +def selectin_polymorphic(base_cls, classes): + ul = _UnboundLoad() + ul.is_class_strategy = True + ul.path = (inspect(base_cls), ) + ul.selectin_polymorphic( + classes + ) + return ul diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 9a397ccf3..4267b79fb 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1043,7 +1043,13 @@ def was_deleted(object): state = attributes.instance_state(object) return state.was_deleted + def _entity_corresponds_to(given, entity): + """determine if 'given' corresponds to 'entity', in terms + of an entity passed to Query that would match the same entity + being referred to elsewhere in the query. + + """ if entity.is_aliased_class: if given.is_aliased_class: if entity._base_alias is given._base_alias: @@ -1057,6 +1063,21 @@ def _entity_corresponds_to(given, entity): return entity.common_parent(given) + +def _entity_isa(given, mapper): + """determine if 'given' "is a" mapper, in terms of the given + would load rows of type 'mapper'. + + """ + if given.is_aliased_class: + return mapper in given.with_polymorphic_mappers or \ + given.mapper.isa(mapper) + elif given.with_polymorphic_mappers: + return mapper in given.with_polymorphic_mappers + else: + return given.isa(mapper) + + def randomize_unitofwork(): """Use random-ordering sets within the unit of work in order to detect unit of work sorting issues. diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index dfea33dc7..c0854ea55 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -497,8 +497,9 @@ class AssertsExecutionResults(object): def assert_sql_execution(self, db, callable_, *rules): with self.sql_execution_asserter(db) as asserter: - callable_() + result = callable_() asserter.assert_(*rules) + return result def assert_sql(self, db, callable_, rules): @@ -512,7 +513,7 @@ class AssertsExecutionResults(object): newrule = assertsql.CompiledSQL(*rule) newrules.append(newrule) - self.assert_sql_execution(db, callable_, *newrules) + return self.assert_sql_execution(db, callable_, *newrules) def assert_sql_count(self, db, callable_, count): self.assert_sql_execution( diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index e39b6315d..86d850733 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -282,6 +282,32 @@ class AllOf(AssertRule): self.errormessage = list(self.rules)[0].errormessage +class EachOf(AssertRule): + + def __init__(self, *rules): + self.rules = list(rules) + + def process_statement(self, execute_observed): + while self.rules: + rule = self.rules[0] + rule.process_statement(execute_observed) + if rule.is_consumed: + self.rules.pop(0) + elif rule.errormessage: + self.errormessage = rule.errormessage + if rule.consume_statement: + break + + if not self.rules: + self.is_consumed = True + + def no_more_statements(self): + if self.rules and not self.rules[0].is_consumed: + self.rules[0].no_more_statements() + elif self.rules: + super(EachOf, self).no_more_statements() + + class Or(AllOf): def process_statement(self, execute_observed): @@ -319,24 +345,20 @@ class SQLAsserter(object): del self.accumulated def assert_(self, *rules): - rules = list(rules) - observed = list(self._final) + rule = EachOf(*rules) - while observed and rules: - rule = rules[0] - rule.process_statement(observed[0]) + observed = list(self._final) + while observed: + statement = observed.pop(0) + rule.process_statement(statement) if rule.is_consumed: - rules.pop(0) + break elif rule.errormessage: assert False, rule.errormessage - - if rule.consume_statement: - observed.pop(0) - - if not observed and rules: - rules[0].no_more_statements() - elif not rules and observed: + if observed: assert False, "Additional SQL statements remain" + elif not rule.is_consumed: + rule.no_more_statements() @contextlib.contextmanager |
