diff options
Diffstat (limited to 'lib/sqlalchemy/ext/associationproxy.py')
-rw-r--r-- | lib/sqlalchemy/ext/associationproxy.py | 30 |
1 files changed, 26 insertions, 4 deletions
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 0482a9205..fca2f0008 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -17,7 +17,7 @@ import operator import weakref from .. import exc, orm, util from ..orm import collections, interfaces -from ..sql import not_ +from ..sql import not_, or_ def association_proxy(target_collection, attr, **kw): @@ -231,6 +231,10 @@ class AssociationProxy(interfaces._InspectionAttr): return not self._get_property().\ mapper.get_property(self.value_attr).uselist + @util.memoized_property + def _target_is_object(self): + return getattr(self.target_class, self.value_attr).impl.uses_objects + def __get__(self, obj, class_): if self.owning_class is None: self.owning_class = class_ and class_ or type(obj) @@ -388,10 +392,17 @@ class AssociationProxy(interfaces._InspectionAttr): """ - return self._comparator.has( + if self._target_is_object: + return self._comparator.has( getattr(self.target_class, self.value_attr).\ has(criterion, **kwargs) ) + else: + if criterion is not None or kwargs: + raise exc.ArgumentError( + "Non-empty has() not allowed for " + "column-targeted association proxy; use ==") + return self._comparator.has() def contains(self, obj): """Produce a proxied 'contains' expression using EXISTS. @@ -411,10 +422,21 @@ class AssociationProxy(interfaces._InspectionAttr): return self._comparator.any(**{self.value_attr: obj}) def __eq__(self, obj): - return self._comparator.has(**{self.value_attr: obj}) + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + if obj is None: + return or_( + self._comparator.has(**{self.value_attr: obj}), + self._comparator == None + ) + else: + return self._comparator.has(**{self.value_attr: obj}) def __ne__(self, obj): - return not_(self.__eq__(obj)) + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + return self._comparator.has( + getattr(self.target_class, self.value_attr) != obj) class _lazy_collection(object): |