diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-07-05 10:19:59 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-07-05 10:19:59 -0400 |
commit | b297b40fca923a03e3c34094e5298d6524944c39 (patch) | |
tree | 286e1440f00f2f730c2ec59aa50fceb4e19b09fe /lib/sqlalchemy/sql/util.py | |
parent | 838c4eca94918b8db38eeb7faf48e63d6b2375b0 (diff) | |
download | sqlalchemy-b297b40fca923a03e3c34094e5298d6524944c39.tar.gz |
- [bug] ORM will perform extra effort to determine
that an FK dependency between two tables is
not significant during flush if the tables
are related via joined inheritance and the FK
dependency is not part of the inherit_condition,
saves the user a use_alter directive.
[ticket:2527]
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 118 |
1 files changed, 60 insertions, 58 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 83a75a77f..b0818b891 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -12,7 +12,7 @@ from collections import deque """Utility functions that build upon SQL and Schema constructs.""" -def sort_tables(tables): +def sort_tables(tables, skip_fn=None): """sort a collection of Table objects in order of their foreign-key dependency.""" tables = list(tables) @@ -20,6 +20,8 @@ def sort_tables(tables): def visit_foreign_key(fkey): if fkey.use_alter: return + elif skip_fn and skip_fn(fkey): + return parent_table = fkey.column.table if parent_table in tables: child_table = fkey.parent.table @@ -27,8 +29,8 @@ def sort_tables(tables): tuples.append((parent_table, child_table)) for table in tables: - visitors.traverse(table, - {'schema_visitor':True}, + visitors.traverse(table, + {'schema_visitor':True}, {'foreign_key':visit_foreign_key}) tuples.extend( @@ -38,9 +40,9 @@ def sort_tables(tables): return list(topological.sort(tuples, tables)) def find_join_source(clauses, join_to): - """Given a list of FROM clauses and a selectable, - return the first index and element from the list of - clauses which can be joined against the selectable. returns + """Given a list of FROM clauses and a selectable, + return the first index and element from the list of + clauses which can be joined against the selectable. returns None, None if no match is found. e.g.:: @@ -66,25 +68,25 @@ def find_join_source(clauses, join_to): def visit_binary_product(fn, expr): """Produce a traversal of the given expression, delivering column comparisons to the given function. - + The function is of the form:: - + def my_fn(binary, left, right) - - For each binary expression located which has a + + For each binary expression located which has a comparison operator, the product of "left" and "right" will be delivered to that function, in terms of that binary. - + Hence an expression like:: - + and_( (a + b) == q + func.sum(e + f), j == r ) - + would have the traversal:: - + a <eq> q a <eq> e a <eq> f @@ -96,7 +98,7 @@ def visit_binary_product(fn, expr): That is, every combination of "left" and "right" that doesn't further contain a binary comparison is passed as pairs. - + """ stack = [] def visit(element): @@ -121,8 +123,8 @@ def visit_binary_product(fn, expr): yield e list(visit(expr)) -def find_tables(clause, check_columns=False, - include_aliases=False, include_joins=False, +def find_tables(clause, check_columns=False, + include_aliases=False, include_joins=False, include_selects=False, include_crud=False): """locate Table objects within the given expression.""" @@ -171,7 +173,7 @@ def unwrap_order_by(clause): ( not isinstance(t, expression._UnaryExpression) or \ not operators.is_ordering_modifier(t.modifier) - ): + ): cols.add(t) else: for c in t.get_children(): @@ -226,7 +228,7 @@ def _quote_ddl_expr(element): class _repr_params(object): """A string view of bound parameters, truncating display to the given number of 'multi' parameter sets. - + """ def __init__(self, params, batches): self.params = params @@ -246,7 +248,7 @@ class _repr_params(object): def expression_as_ddl(clause): - """Given a SQL expression, convert for usage in DDL, such as + """Given a SQL expression, convert for usage in DDL, such as CREATE INDEX and CHECK CONSTRAINT. Converts bind params into quoted literals, column identifiers @@ -285,7 +287,7 @@ def adapt_criterion_to_null(crit, nulls): return visitors.cloned_traverse(crit, {}, {'binary':visit_binary}) -def join_condition(a, b, ignore_nonexistent_tables=False, +def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None, consider_as_foreign_keys=None): """create a join condition between two tables or selectables. @@ -321,7 +323,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False, if left is None: continue for fk in sorted( - b.foreign_keys, + b.foreign_keys, key=lambda fk:fk.parent._creation_order): if consider_as_foreign_keys is not None and \ fk.parent not in consider_as_foreign_keys: @@ -339,7 +341,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False, constraints.add(fk.constraint) if left is not b: for fk in sorted( - left.foreign_keys, + left.foreign_keys, key=lambda fk:fk.parent._creation_order): if consider_as_foreign_keys is not None and \ fk.parent not in consider_as_foreign_keys: @@ -385,12 +387,12 @@ def join_condition(a, b, ignore_nonexistent_tables=False, class Annotated(object): """clones a ClauseElement and applies an 'annotations' dictionary. - Unlike regular clones, this clone also mimics __hash__() and + Unlike regular clones, this clone also mimics __hash__() and __cmp__() of the original element so that it takes its place in hashed collections. A reference to the original element is maintained, for the important - reason of keeping its hash value current. When GC'ed, the + reason of keeping its hash value current. When GC'ed, the hash value may be reused, causing conflicts. """ @@ -406,13 +408,13 @@ class Annotated(object): try: cls = annotated_classes[element.__class__] except KeyError: - cls = annotated_classes[element.__class__] = type.__new__(type, - "Annotated%s" % element.__class__.__name__, + cls = annotated_classes[element.__class__] = type.__new__(type, + "Annotated%s" % element.__class__.__name__, (Annotated, element.__class__), {}) return object.__new__(cls) def __init__(self, element, values): - # force FromClause to generate their internal + # force FromClause to generate their internal # collections into __dict__ if isinstance(element, expression.FromClause): element.c @@ -481,7 +483,7 @@ for cls in expression.__dict__.values() + [schema.Column, schema.Table]: exec "annotated_classes[cls] = Annotated%s" % (cls.__name__) def _deep_annotate(element, annotations, exclude=None): - """Deep copy the given ClauseElement, annotating each element + """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. Elements within the exclude collection will be cloned but not annotated. @@ -528,17 +530,17 @@ def _deep_deannotate(element, values=None): element = clone(element) return element -def _shallow_annotate(element, annotations): - """Annotate the given ClauseElement and copy its internals so that - internal objects refer to the new annotated object. +def _shallow_annotate(element, annotations): + """Annotate the given ClauseElement and copy its internals so that + internal objects refer to the new annotated object. - Basically used to apply a "dont traverse" annotation to a - selectable, without digging throughout the whole - structure wasting time. - """ - element = element._annotate(annotations) - element._copy_internals() - return element + Basically used to apply a "dont traverse" annotation to a + selectable, without digging throughout the whole + structure wasting time. + """ + element = element._annotate(annotations) + element._copy_internals() + return element def splice_joins(left, right, stop_on=None): if left is None: @@ -626,7 +628,7 @@ def reduce_columns(columns, *clauses, **kw): return expression.ColumnSet(columns.difference(omit)) -def criterion_as_pairs(expression, consider_as_foreign_keys=None, +def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_referenced_keys=None, any_operator=False): """traverse an expression and locate binary criterion pairs.""" @@ -648,20 +650,20 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, if consider_as_foreign_keys: if binary.left in consider_as_foreign_keys and \ - (col_is(binary.right, binary.left) or + (col_is(binary.right, binary.left) or binary.right not in consider_as_foreign_keys): pairs.append((binary.right, binary.left)) elif binary.right in consider_as_foreign_keys and \ - (col_is(binary.left, binary.right) or + (col_is(binary.left, binary.right) or binary.left not in consider_as_foreign_keys): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: if binary.left in consider_as_referenced_keys and \ - (col_is(binary.right, binary.left) or + (col_is(binary.right, binary.left) or binary.right not in consider_as_referenced_keys): pairs.append((binary.left, binary.right)) elif binary.right in consider_as_referenced_keys and \ - (col_is(binary.left, binary.right) or + (col_is(binary.left, binary.right) or binary.left not in consider_as_referenced_keys): pairs.append((binary.right, binary.left)) else: @@ -678,17 +680,17 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, def folded_equivalents(join, equivs=None): """Return a list of uniquely named columns. - The column list of the given Join will be narrowed + The column list of the given Join will be narrowed down to a list of all equivalently-named, equated columns folded into one column, where 'equated' means they are equated to each other in the ON clause of this join. This function is used by Join.select(fold_equivalents=True). - Deprecated. This function is used for a certain kind of + Deprecated. This function is used for a certain kind of "polymorphic_union" which is designed to achieve joined table inheritance where the base table has no "discriminator" - column; [ticket:1131] will provide a better way to + column; [ticket:1131] will provide a better way to achieve this. """ @@ -773,9 +775,9 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): s.c.col1 == table2.c.col1 """ - def __init__(self, selectable, equivalents=None, - include=None, exclude=None, - include_fn=None, exclude_fn=None, + def __init__(self, selectable, equivalents=None, + include=None, exclude=None, + include_fn=None, exclude_fn=None, adapt_on_names=False): self.__traverse_options__ = {'stop_on':[selectable]} self.selectable = selectable @@ -794,12 +796,12 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET): newcol = self.selectable.corresponding_column( - col, + col, require_embedded=require_embedded) if newcol is None and col in self.equivalents and col not in _seen: for equiv in self.equivalents[col]: - newcol = self._corresponding_column(equiv, - require_embedded=require_embedded, + newcol = self._corresponding_column(equiv, + require_embedded=require_embedded, _seen=_seen.union([col])) if newcol is not None: return newcol @@ -823,14 +825,14 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): class ColumnAdapter(ClauseAdapter): """Extends ClauseAdapter with extra utility functions. - Provides the ability to "wrap" this ClauseAdapter + Provides the ability to "wrap" this ClauseAdapter around another, a columns dictionary which returns - adapted elements given an original, and an + adapted elements given an original, and an adapted_row() factory. """ - def __init__(self, selectable, equivalents=None, - chain_to=None, include=None, + def __init__(self, selectable, equivalents=None, + chain_to=None, include=None, exclude=None, adapt_required=False): ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) if chain_to: @@ -866,7 +868,7 @@ class ColumnAdapter(ClauseAdapter): c = c.label(None) # adapt_required indicates that if we got the same column - # back which we put in (i.e. it passed through), + # back which we put in (i.e. it passed through), # it's not correct. this is used by eagerloading which # knows that all columns and expressions need to be adapted # to a result row, and a "passthrough" is definitely targeting |