diff options
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r-- | lib/sqlalchemy/sql/util.py | 118 |
1 files changed, 60 insertions, 58 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 83a75a77f..b0818b891 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -12,7 +12,7 @@ from collections import deque """Utility functions that build upon SQL and Schema constructs.""" -def sort_tables(tables): +def sort_tables(tables, skip_fn=None): """sort a collection of Table objects in order of their foreign-key dependency.""" tables = list(tables) @@ -20,6 +20,8 @@ def sort_tables(tables): def visit_foreign_key(fkey): if fkey.use_alter: return + elif skip_fn and skip_fn(fkey): + return parent_table = fkey.column.table if parent_table in tables: child_table = fkey.parent.table @@ -27,8 +29,8 @@ def sort_tables(tables): tuples.append((parent_table, child_table)) for table in tables: - visitors.traverse(table, - {'schema_visitor':True}, + visitors.traverse(table, + {'schema_visitor':True}, {'foreign_key':visit_foreign_key}) tuples.extend( @@ -38,9 +40,9 @@ def sort_tables(tables): return list(topological.sort(tuples, tables)) def find_join_source(clauses, join_to): - """Given a list of FROM clauses and a selectable, - return the first index and element from the list of - clauses which can be joined against the selectable. returns + """Given a list of FROM clauses and a selectable, + return the first index and element from the list of + clauses which can be joined against the selectable. returns None, None if no match is found. e.g.:: @@ -66,25 +68,25 @@ def find_join_source(clauses, join_to): def visit_binary_product(fn, expr): """Produce a traversal of the given expression, delivering column comparisons to the given function. - + The function is of the form:: - + def my_fn(binary, left, right) - - For each binary expression located which has a + + For each binary expression located which has a comparison operator, the product of "left" and "right" will be delivered to that function, in terms of that binary. - + Hence an expression like:: - + and_( (a + b) == q + func.sum(e + f), j == r ) - + would have the traversal:: - + a <eq> q a <eq> e a <eq> f @@ -96,7 +98,7 @@ def visit_binary_product(fn, expr): That is, every combination of "left" and "right" that doesn't further contain a binary comparison is passed as pairs. - + """ stack = [] def visit(element): @@ -121,8 +123,8 @@ def visit_binary_product(fn, expr): yield e list(visit(expr)) -def find_tables(clause, check_columns=False, - include_aliases=False, include_joins=False, +def find_tables(clause, check_columns=False, + include_aliases=False, include_joins=False, include_selects=False, include_crud=False): """locate Table objects within the given expression.""" @@ -171,7 +173,7 @@ def unwrap_order_by(clause): ( not isinstance(t, expression._UnaryExpression) or \ not operators.is_ordering_modifier(t.modifier) - ): + ): cols.add(t) else: for c in t.get_children(): @@ -226,7 +228,7 @@ def _quote_ddl_expr(element): class _repr_params(object): """A string view of bound parameters, truncating display to the given number of 'multi' parameter sets. - + """ def __init__(self, params, batches): self.params = params @@ -246,7 +248,7 @@ class _repr_params(object): def expression_as_ddl(clause): - """Given a SQL expression, convert for usage in DDL, such as + """Given a SQL expression, convert for usage in DDL, such as CREATE INDEX and CHECK CONSTRAINT. Converts bind params into quoted literals, column identifiers @@ -285,7 +287,7 @@ def adapt_criterion_to_null(crit, nulls): return visitors.cloned_traverse(crit, {}, {'binary':visit_binary}) -def join_condition(a, b, ignore_nonexistent_tables=False, +def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None, consider_as_foreign_keys=None): """create a join condition between two tables or selectables. @@ -321,7 +323,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False, if left is None: continue for fk in sorted( - b.foreign_keys, + b.foreign_keys, key=lambda fk:fk.parent._creation_order): if consider_as_foreign_keys is not None and \ fk.parent not in consider_as_foreign_keys: @@ -339,7 +341,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False, constraints.add(fk.constraint) if left is not b: for fk in sorted( - left.foreign_keys, + left.foreign_keys, key=lambda fk:fk.parent._creation_order): if consider_as_foreign_keys is not None and \ fk.parent not in consider_as_foreign_keys: @@ -385,12 +387,12 @@ def join_condition(a, b, ignore_nonexistent_tables=False, class Annotated(object): """clones a ClauseElement and applies an 'annotations' dictionary. - Unlike regular clones, this clone also mimics __hash__() and + Unlike regular clones, this clone also mimics __hash__() and __cmp__() of the original element so that it takes its place in hashed collections. A reference to the original element is maintained, for the important - reason of keeping its hash value current. When GC'ed, the + reason of keeping its hash value current. When GC'ed, the hash value may be reused, causing conflicts. """ @@ -406,13 +408,13 @@ class Annotated(object): try: cls = annotated_classes[element.__class__] except KeyError: - cls = annotated_classes[element.__class__] = type.__new__(type, - "Annotated%s" % element.__class__.__name__, + cls = annotated_classes[element.__class__] = type.__new__(type, + "Annotated%s" % element.__class__.__name__, (Annotated, element.__class__), {}) return object.__new__(cls) def __init__(self, element, values): - # force FromClause to generate their internal + # force FromClause to generate their internal # collections into __dict__ if isinstance(element, expression.FromClause): element.c @@ -481,7 +483,7 @@ for cls in expression.__dict__.values() + [schema.Column, schema.Table]: exec "annotated_classes[cls] = Annotated%s" % (cls.__name__) def _deep_annotate(element, annotations, exclude=None): - """Deep copy the given ClauseElement, annotating each element + """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. Elements within the exclude collection will be cloned but not annotated. @@ -528,17 +530,17 @@ def _deep_deannotate(element, values=None): element = clone(element) return element -def _shallow_annotate(element, annotations): - """Annotate the given ClauseElement and copy its internals so that - internal objects refer to the new annotated object. +def _shallow_annotate(element, annotations): + """Annotate the given ClauseElement and copy its internals so that + internal objects refer to the new annotated object. - Basically used to apply a "dont traverse" annotation to a - selectable, without digging throughout the whole - structure wasting time. - """ - element = element._annotate(annotations) - element._copy_internals() - return element + Basically used to apply a "dont traverse" annotation to a + selectable, without digging throughout the whole + structure wasting time. + """ + element = element._annotate(annotations) + element._copy_internals() + return element def splice_joins(left, right, stop_on=None): if left is None: @@ -626,7 +628,7 @@ def reduce_columns(columns, *clauses, **kw): return expression.ColumnSet(columns.difference(omit)) -def criterion_as_pairs(expression, consider_as_foreign_keys=None, +def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_referenced_keys=None, any_operator=False): """traverse an expression and locate binary criterion pairs.""" @@ -648,20 +650,20 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, if consider_as_foreign_keys: if binary.left in consider_as_foreign_keys and \ - (col_is(binary.right, binary.left) or + (col_is(binary.right, binary.left) or binary.right not in consider_as_foreign_keys): pairs.append((binary.right, binary.left)) elif binary.right in consider_as_foreign_keys and \ - (col_is(binary.left, binary.right) or + (col_is(binary.left, binary.right) or binary.left not in consider_as_foreign_keys): pairs.append((binary.left, binary.right)) elif consider_as_referenced_keys: if binary.left in consider_as_referenced_keys and \ - (col_is(binary.right, binary.left) or + (col_is(binary.right, binary.left) or binary.right not in consider_as_referenced_keys): pairs.append((binary.left, binary.right)) elif binary.right in consider_as_referenced_keys and \ - (col_is(binary.left, binary.right) or + (col_is(binary.left, binary.right) or binary.left not in consider_as_referenced_keys): pairs.append((binary.right, binary.left)) else: @@ -678,17 +680,17 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, def folded_equivalents(join, equivs=None): """Return a list of uniquely named columns. - The column list of the given Join will be narrowed + The column list of the given Join will be narrowed down to a list of all equivalently-named, equated columns folded into one column, where 'equated' means they are equated to each other in the ON clause of this join. This function is used by Join.select(fold_equivalents=True). - Deprecated. This function is used for a certain kind of + Deprecated. This function is used for a certain kind of "polymorphic_union" which is designed to achieve joined table inheritance where the base table has no "discriminator" - column; [ticket:1131] will provide a better way to + column; [ticket:1131] will provide a better way to achieve this. """ @@ -773,9 +775,9 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): s.c.col1 == table2.c.col1 """ - def __init__(self, selectable, equivalents=None, - include=None, exclude=None, - include_fn=None, exclude_fn=None, + def __init__(self, selectable, equivalents=None, + include=None, exclude=None, + include_fn=None, exclude_fn=None, adapt_on_names=False): self.__traverse_options__ = {'stop_on':[selectable]} self.selectable = selectable @@ -794,12 +796,12 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET): newcol = self.selectable.corresponding_column( - col, + col, require_embedded=require_embedded) if newcol is None and col in self.equivalents and col not in _seen: for equiv in self.equivalents[col]: - newcol = self._corresponding_column(equiv, - require_embedded=require_embedded, + newcol = self._corresponding_column(equiv, + require_embedded=require_embedded, _seen=_seen.union([col])) if newcol is not None: return newcol @@ -823,14 +825,14 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): class ColumnAdapter(ClauseAdapter): """Extends ClauseAdapter with extra utility functions. - Provides the ability to "wrap" this ClauseAdapter + Provides the ability to "wrap" this ClauseAdapter around another, a columns dictionary which returns - adapted elements given an original, and an + adapted elements given an original, and an adapted_row() factory. """ - def __init__(self, selectable, equivalents=None, - chain_to=None, include=None, + def __init__(self, selectable, equivalents=None, + chain_to=None, include=None, exclude=None, adapt_required=False): ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) if chain_to: @@ -866,7 +868,7 @@ class ColumnAdapter(ClauseAdapter): c = c.label(None) # adapt_required indicates that if we got the same column - # back which we put in (i.e. it passed through), + # back which we put in (i.e. it passed through), # it's not correct. this is used by eagerloading which # knows that all columns and expressions need to be adapted # to a result row, and a "passthrough" is definitely targeting |