diff options
Diffstat (limited to 'lib/sqlalchemy/orm/evaluator.py')
| -rw-r--r-- | lib/sqlalchemy/orm/evaluator.py | 101 |
1 files changed, 71 insertions, 30 deletions
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 4abf08ab1..ac031d84f 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -14,17 +14,40 @@ from .. import util class UnevaluatableError(Exception): pass -_straight_ops = set(getattr(operators, op) - for op in ('add', 'mul', 'sub', - 'div', - 'mod', 'truediv', - 'lt', 'le', 'ne', 'gt', 'ge', 'eq')) - -_notimplemented_ops = set(getattr(operators, op) - for op in ('like_op', 'notlike_op', 'ilike_op', - 'notilike_op', 'between_op', 'in_op', - 'notin_op', 'endswith_op', 'concat_op')) +_straight_ops = set( + getattr(operators, op) + for op in ( + "add", + "mul", + "sub", + "div", + "mod", + "truediv", + "lt", + "le", + "ne", + "gt", + "ge", + "eq", + ) +) + + +_notimplemented_ops = set( + getattr(operators, op) + for op in ( + "like_op", + "notlike_op", + "ilike_op", + "notilike_op", + "between_op", + "in_op", + "notin_op", + "endswith_op", + "concat_op", + ) +) class EvaluatorCompiler(object): @@ -35,7 +58,8 @@ class EvaluatorCompiler(object): meth = getattr(self, "visit_%s" % clause.__visit_name__, None) if not meth: raise UnevaluatableError( - "Cannot evaluate %s" % type(clause).__name__) + "Cannot evaluate %s" % type(clause).__name__ + ) return meth(clause) def visit_grouping(self, clause): @@ -51,28 +75,30 @@ class EvaluatorCompiler(object): return lambda obj: True def visit_column(self, clause): - if 'parentmapper' in clause._annotations: - parentmapper = clause._annotations['parentmapper'] + if "parentmapper" in clause._annotations: + parentmapper = clause._annotations["parentmapper"] if self.target_cls and not issubclass( - self.target_cls, parentmapper.class_): + self.target_cls, parentmapper.class_ + ): raise UnevaluatableError( - "Can't evaluate criteria against alternate class %s" % - parentmapper.class_ + "Can't evaluate criteria against alternate class %s" + % parentmapper.class_ ) key = parentmapper._columntoproperty[clause].key else: key = clause.key - if self.target_cls and \ - key in inspect(self.target_cls).column_attrs: + if ( + self.target_cls + and key in inspect(self.target_cls).column_attrs + ): util.warn( "Evaluating non-mapped column expression '%s' onto " "ORM instances; this is a deprecated use case. Please " "make use of the actual mapped columns in ORM-evaluated " - "UPDATE / DELETE expressions." % clause) - else: - raise UnevaluatableError( - "Cannot evaluate column: %s" % clause + "UPDATE / DELETE expressions." % clause ) + else: + raise UnevaluatableError("Cannot evaluate column: %s" % clause) get_corresponding_attr = operator.attrgetter(key) return lambda obj: get_corresponding_attr(obj) @@ -80,6 +106,7 @@ class EvaluatorCompiler(object): def visit_clauselist(self, clause): evaluators = list(map(self.process, clause.clauses)) if clause.operator is operators.or_: + def evaluate(obj): has_null = False for sub_evaluate in evaluators: @@ -90,7 +117,9 @@ class EvaluatorCompiler(object): if has_null: return None return False + elif clause.operator is operators.and_: + def evaluate(obj): for sub_evaluate in evaluators: value = sub_evaluate(obj) @@ -99,48 +128,60 @@ class EvaluatorCompiler(object): return None return False return True + else: raise UnevaluatableError( - "Cannot evaluate clauselist with operator %s" % - clause.operator) + "Cannot evaluate clauselist with operator %s" % clause.operator + ) return evaluate def visit_binary(self, clause): - eval_left, eval_right = list(map(self.process, - [clause.left, clause.right])) + eval_left, eval_right = list( + map(self.process, [clause.left, clause.right]) + ) operator = clause.operator if operator is operators.is_: + def evaluate(obj): return eval_left(obj) == eval_right(obj) + elif operator is operators.isnot: + def evaluate(obj): return eval_left(obj) != eval_right(obj) + elif operator in _straight_ops: + def evaluate(obj): left_val = eval_left(obj) right_val = eval_right(obj) if left_val is None or right_val is None: return None return operator(eval_left(obj), eval_right(obj)) + else: raise UnevaluatableError( - "Cannot evaluate %s with operator %s" % - (type(clause).__name__, clause.operator)) + "Cannot evaluate %s with operator %s" + % (type(clause).__name__, clause.operator) + ) return evaluate def visit_unary(self, clause): eval_inner = self.process(clause.element) if clause.operator is operators.inv: + def evaluate(obj): value = eval_inner(obj) if value is None: return None return not value + return evaluate raise UnevaluatableError( - "Cannot evaluate %s with operator %s" % - (type(clause).__name__, clause.operator)) + "Cannot evaluate %s with operator %s" + % (type(clause).__name__, clause.operator) + ) def visit_bindparam(self, clause): if clause.callable: |
