From 73f734bf80166c7dfce4892941752d7569a17524 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 6 Feb 2012 12:20:15 -0500 Subject: initial annotations approach to join conditions. all tests pass, plus additional tests in #1401 pass. would now like to reorganize RelationshipProperty more around the annotations concept. --- lib/sqlalchemy/sql/expression.py | 2 +- lib/sqlalchemy/sql/util.py | 48 ++++++++++++++++++++++++++-------------- lib/sqlalchemy/sql/visitors.py | 6 ++--- 3 files changed, 35 insertions(+), 21 deletions(-) (limited to 'lib/sqlalchemy/sql') diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index b11e5ad42..30e19bc68 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -2184,7 +2184,7 @@ class ColumnElement(ClauseElement, _CompareMixin): for oth in to_compare: if use_proxies and self.shares_lineage(oth): return True - elif oth is self: + elif hash(oth) == hash(self): return True else: return False diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 97975441e..f0509c16f 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -225,7 +225,8 @@ def adapt_criterion_to_null(crit, nulls): return visitors.cloned_traverse(crit, {}, {'binary':visit_binary}) -def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None): +def join_condition(a, b, ignore_nonexistent_tables=False, + a_subset=None): """create a join condition between two tables or selectables. e.g.:: @@ -535,6 +536,10 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, "'consider_as_foreign_keys' or " "'consider_as_referenced_keys'") + def col_is(a, b): + #return a is b + return a.compare(b) + def visit_binary(binary): if not any_operator and binary.operator is not operators.eq: return @@ -544,20 +549,20 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, if consider_as_foreign_keys: if binary.left in consider_as_foreign_keys and \ - (binary.right is binary.left or + (col_is(binary.right, binary.left) or binary.right not in consider_as_foreign_keys): pairs.append((binary.right, binary.left)) elif binary.right in consider_as_foreign_keys and \ - (binary.left is binary.right or + (col_is(binary.left, binary.right) or binary.left not in consider_as_foreign_keys): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: if binary.left in consider_as_referenced_keys and \ - (binary.right is binary.left or + (col_is(binary.right, binary.left) or binary.right not in consider_as_referenced_keys): pairs.append((binary.left, binary.right)) elif binary.right in consider_as_referenced_keys and \ - (binary.left is binary.right or + (col_is(binary.left, binary.right) or binary.left not in consider_as_referenced_keys): pairs.append((binary.right, binary.left)) else: @@ -669,11 +674,22 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): s.c.col1 == table2.c.col1 """ - def __init__(self, selectable, equivalents=None, include=None, exclude=None, adapt_on_names=False): + def __init__(self, selectable, equivalents=None, + include=None, exclude=None, + include_fn=None, exclude_fn=None, + adapt_on_names=False): self.__traverse_options__ = {'stop_on':[selectable]} self.selectable = selectable - self.include = include - self.exclude = exclude + if include: + assert not include_fn + self.include_fn = lambda e: e in include + else: + self.include_fn = include_fn + if exclude: + assert not exclude_fn + self.exclude_fn = lambda e: e in exclude + else: + self.exclude_fn = exclude_fn self.equivalents = util.column_dict(equivalents or {}) self.adapt_on_names = adapt_on_names @@ -693,19 +709,17 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): return newcol def replace(self, col): - if isinstance(col, expression.FromClause): - if self.selectable.is_derived_from(col): + if isinstance(col, expression.FromClause) and \ + self.selectable.is_derived_from(col): return self.selectable - - if not isinstance(col, expression.ColumnElement): + elif not isinstance(col, expression.ColumnElement): return None - - if self.include and col not in self.include: + elif self.include_fn and not self.include_fn(col): return None - elif self.exclude and col in self.exclude: + elif self.exclude_fn and self.exclude_fn(col): return None - - return self._corresponding_column(col, True) + else: + return self._corresponding_column(col, True) class ColumnAdapter(ClauseAdapter): """Extends ClauseAdapter with extra utility functions. diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index cdcf40aa8..75e099f0d 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -240,16 +240,16 @@ def replacement_traverse(obj, opts, replace): replacement by a given replacement function.""" cloned = util.column_dict() - stop_on = util.column_set(opts.get('stop_on', [])) + stop_on = util.column_set([id(x) for x in opts.get('stop_on', [])]) def clone(elem, **kw): - if elem in stop_on or \ + if id(elem) in stop_on or \ 'no_replacement_traverse' in elem._annotations: return elem else: newelem = replace(elem) if newelem is not None: - stop_on.add(newelem) + stop_on.add(id(newelem)) return newelem else: if elem not in cloned: -- cgit v1.2.1 From d1414ad20524c421aa78272c03dce5f839a0aab6 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 8 Feb 2012 10:14:36 -0500 Subject: simplify remote annotation significantly, and also catch the actual remote columns more accurately. --- lib/sqlalchemy/sql/expression.py | 4 ++++ lib/sqlalchemy/sql/operators.py | 5 +++++ 2 files changed, 9 insertions(+) (limited to 'lib/sqlalchemy/sql') diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 30e19bc68..72099a5f5 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -3384,6 +3384,10 @@ class _BinaryExpression(ColumnElement): except: raise TypeError("Boolean value of this clause is not defined") + @property + def is_comparison(self): + return operators.is_comparison(self.operator) + @property def _from_objects(self): return self.left._from_objects + self.right._from_objects diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 89f0aaee1..b86b50db4 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -521,6 +521,11 @@ def nullslast_op(a): _commutative = set([eq, ne, add, mul]) +_comparison = set([eq, ne, lt, gt, ge, le]) + +def is_comparison(op): + return op in _comparison + def is_commutative(op): return op in _commutative -- cgit v1.2.1 From bc45fa350a02da5f24d866078abed471cd98f15b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 9 Feb 2012 21:16:53 -0500 Subject: - got m2m, local_remote_pairs, etc. working - using new traversal that returns the product of both sides of a binary, starting to work with (a+b) == (c+d) types of joins. primaryjoins on functions working - annotations working, including reversing local/remote when doing backref --- lib/sqlalchemy/sql/expression.py | 24 +++++++++---- lib/sqlalchemy/sql/util.py | 76 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 88 insertions(+), 12 deletions(-) (limited to 'lib/sqlalchemy/sql') diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 72099a5f5..ebf4de9a2 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1576,18 +1576,30 @@ class ClauseElement(Visitable): return id(self) def _annotate(self, values): - """return a copy of this ClauseElement with the given annotations - dictionary. + """return a copy of this ClauseElement with annotations + updated by the given dictionary. """ return sqlutil.Annotated(self, values) - def _deannotate(self): - """return a copy of this ClauseElement with an empty annotations - dictionary. + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. """ - return self._clone() + return sqlutil.Annotated(self, values) + + def _deannotate(self, values=None): + """return a copy of this :class:`.ClauseElement` with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + # since we have no annotations we return + # self + return self def unique_params(self, *optionaldict, **kwargs): """Return a copy with :func:`bindparam()` elments replaced. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index f0509c16f..9a45a5777 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -62,6 +62,61 @@ def find_join_source(clauses, join_to): else: return None, None + +def visit_binary_product(fn, expr): + """Produce a traversal of the given expression, delivering + column comparisons to the given function. + + The function is of the form:: + + def my_fn(binary, left, right) + + For each binary expression located which has a + comparison operator, the product of "left" and + "right" will be delivered to that function, + in terms of that binary. + + Hence an expression like:: + + and_( + (a + b) == q + func.sum(e + f), + j == r + ) + + would have the traversal:: + + a q + a e + a f + b q + b e + b f + j r + + That is, every combination of "left" and + "right" that doesn't further contain + a binary comparison is passed as pairs. + + """ + stack = [] + def visit(element): + if element.__visit_name__ == 'binary' and \ + operators.is_comparison(element.operator): + stack.insert(0, element) + for l in visit(element.left): + for r in visit(element.right): + fn(stack[0], l, r) + stack.pop(0) + for elem in element.get_children(): + visit(elem) + else: + if isinstance(element, expression.ColumnClause): + yield element + for elem in element.get_children(): + for e in visit(elem): + yield e + list(visit(expr)) + def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False, include_crud=False): @@ -357,13 +412,22 @@ class Annotated(object): def _annotate(self, values): _values = self._annotations.copy() _values.update(values) + return self._with_annotations(_values) + + def _with_annotations(self, values): clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() - clone._annotations = _values + clone._annotations = values return clone - def _deannotate(self): - return self.__element + def _deannotate(self, values=None): + if values is None: + return self.__element + else: + _values = self._annotations.copy() + for v in values: + _values.pop(v, None) + return self._with_annotations(_values) def _compiler_dispatch(self, visitor, **kw): return self.__element.__class__._compiler_dispatch(self, visitor, **kw) @@ -426,11 +490,11 @@ def _deep_annotate(element, annotations, exclude=None): element = clone(element) return element -def _deep_deannotate(element): - """Deep copy the given element, removing all annotations.""" +def _deep_deannotate(element, values=None): + """Deep copy the given element, removing annotations.""" def clone(elem): - elem = elem._deannotate() + elem = elem._deannotate(values=values) elem._copy_internals(clone=clone) return elem -- cgit v1.2.1 From 7d693180be8c7f9db79831351751a15d786b86a7 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 10 Feb 2012 17:59:06 -0500 Subject: tweak for correlated subqueries here, seems to work for test_eager_relations:CorrelatedSubqueryTest but need some more testing here --- lib/sqlalchemy/sql/util.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'lib/sqlalchemy/sql') diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 9a45a5777..511a5b0c2 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -100,7 +100,12 @@ def visit_binary_product(fn, expr): """ stack = [] def visit(element): - if element.__visit_name__ == 'binary' and \ + if isinstance(element, (expression.FromClause, + expression._ScalarSelect)): + # we dont want to dig into correlated subqueries, + # those are just column elements by themselves + yield element + elif element.__visit_name__ == 'binary' and \ operators.is_comparison(element.operator): stack.insert(0, element) for l in visit(element.left): -- cgit v1.2.1 From 0634ea79b1a23a8b88c886a8a3f434ed300691e2 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 11 Feb 2012 15:43:05 -0500 Subject: many fixes but still can't get heuristics to work as well as what's existing, tests still failing --- lib/sqlalchemy/sql/util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'lib/sqlalchemy/sql') diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 511a5b0c2..e4e2c00e1 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -100,8 +100,7 @@ def visit_binary_product(fn, expr): """ stack = [] def visit(element): - if isinstance(element, (expression.FromClause, - expression._ScalarSelect)): + if isinstance(element, (expression._ScalarSelect)): # we dont want to dig into correlated subqueries, # those are just column elements by themselves yield element -- cgit v1.2.1 From d934ea23e24880a5c784c9e5edf9ead5bc965a83 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 11 Feb 2012 20:33:56 -0500 Subject: - figured out again why deannotate must clone() - got everything working. just need to update error strings --- lib/sqlalchemy/sql/expression.py | 13 +++++++++---- lib/sqlalchemy/sql/util.py | 4 ++-- lib/sqlalchemy/sql/visitors.py | 6 +++--- 3 files changed, 14 insertions(+), 9 deletions(-) (limited to 'lib/sqlalchemy/sql') diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index ebf4de9a2..573ace47f 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1589,7 +1589,7 @@ class ClauseElement(Visitable): """ return sqlutil.Annotated(self, values) - def _deannotate(self, values=None): + def _deannotate(self, values=None, clone=False): """return a copy of this :class:`.ClauseElement` with annotations removed. @@ -1597,9 +1597,14 @@ class ClauseElement(Visitable): to remove. """ - # since we have no annotations we return - # self - return self + if clone: + # clone is used when we are also copying + # the expression for a deep deannotation + return self._clone() + else: + # if no clone, since we have no annotations we return + # self + return self def unique_params(self, *optionaldict, **kwargs): """Return a copy with :func:`bindparam()` elments replaced. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index e4e2c00e1..2862e9af9 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -424,7 +424,7 @@ class Annotated(object): clone._annotations = values return clone - def _deannotate(self, values=None): + def _deannotate(self, values=None, clone=True): if values is None: return self.__element else: @@ -498,7 +498,7 @@ def _deep_deannotate(element, values=None): """Deep copy the given element, removing annotations.""" def clone(elem): - elem = elem._deannotate(values=values) + elem = elem._deannotate(values=values, clone=True) elem._copy_internals(clone=clone) return elem diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 75e099f0d..cd178b716 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -222,13 +222,13 @@ def cloned_traverse(obj, opts, visitors): if elem in stop_on: return elem else: - if elem not in cloned: - cloned[elem] = newelem = elem._clone() + if id(elem) not in cloned: + cloned[id(elem)] = newelem = elem._clone() newelem._copy_internals(clone=clone) meth = visitors.get(newelem.__visit_name__, None) if meth: meth(newelem) - return cloned[elem] + return cloned[id(elem)] if obj is not None: obj = clone(obj) -- cgit v1.2.1 From d37320306560c3d758ed65563d53aa9500095a49 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 25 Feb 2012 17:10:06 -0500 Subject: start to work on error messages, allow foreign_keys as only argument if otherwise can't determine join condition due to no fks --- lib/sqlalchemy/sql/util.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'lib/sqlalchemy/sql') diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 2862e9af9..38d95dde5 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -284,8 +284,10 @@ def adapt_criterion_to_null(crit, nulls): return visitors.cloned_traverse(crit, {}, {'binary':visit_binary}) + def join_condition(a, b, ignore_nonexistent_tables=False, - a_subset=None): + a_subset=None, + consider_as_foreign_keys=None): """create a join condition between two tables or selectables. e.g.:: @@ -321,6 +323,9 @@ def join_condition(a, b, ignore_nonexistent_tables=False, for fk in sorted( b.foreign_keys, key=lambda fk:fk.parent._creation_order): + if consider_as_foreign_keys is not None and \ + fk.parent not in consider_as_foreign_keys: + continue try: col = fk.get_referent(left) except exc.NoReferenceError, nrte: @@ -336,6 +341,9 @@ def join_condition(a, b, ignore_nonexistent_tables=False, for fk in sorted( left.foreign_keys, key=lambda fk:fk.parent._creation_order): + if consider_as_foreign_keys is not None and \ + fk.parent not in consider_as_foreign_keys: + continue try: col = fk.get_referent(b) except exc.NoReferenceError, nrte: @@ -358,11 +366,11 @@ def join_condition(a, b, ignore_nonexistent_tables=False, "subquery using alias()?" else: hint = "" - raise exc.ArgumentError( + raise exc.NoForeignKeysError( "Can't find any foreign key relationships " "between '%s' and '%s'.%s" % (a.description, b.description, hint)) elif len(constraints) > 1: - raise exc.ArgumentError( + raise exc.AmbiguousForeignKeysError( "Can't determine join between '%s' and '%s'; " "tables have more than one foreign key " "constraint relationship between them. " -- cgit v1.2.1