summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py118
1 files changed, 60 insertions, 58 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 83a75a77f..b0818b891 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -12,7 +12,7 @@ from collections import deque
"""Utility functions that build upon SQL and Schema constructs."""
-def sort_tables(tables):
+def sort_tables(tables, skip_fn=None):
"""sort a collection of Table objects in order of their foreign-key dependency."""
tables = list(tables)
@@ -20,6 +20,8 @@ def sort_tables(tables):
def visit_foreign_key(fkey):
if fkey.use_alter:
return
+ elif skip_fn and skip_fn(fkey):
+ return
parent_table = fkey.column.table
if parent_table in tables:
child_table = fkey.parent.table
@@ -27,8 +29,8 @@ def sort_tables(tables):
tuples.append((parent_table, child_table))
for table in tables:
- visitors.traverse(table,
- {'schema_visitor':True},
+ visitors.traverse(table,
+ {'schema_visitor':True},
{'foreign_key':visit_foreign_key})
tuples.extend(
@@ -38,9 +40,9 @@ def sort_tables(tables):
return list(topological.sort(tuples, tables))
def find_join_source(clauses, join_to):
- """Given a list of FROM clauses and a selectable,
- return the first index and element from the list of
- clauses which can be joined against the selectable. returns
+ """Given a list of FROM clauses and a selectable,
+ return the first index and element from the list of
+ clauses which can be joined against the selectable. returns
None, None if no match is found.
e.g.::
@@ -66,25 +68,25 @@ def find_join_source(clauses, join_to):
def visit_binary_product(fn, expr):
"""Produce a traversal of the given expression, delivering
column comparisons to the given function.
-
+
The function is of the form::
-
+
def my_fn(binary, left, right)
-
- For each binary expression located which has a
+
+ For each binary expression located which has a
comparison operator, the product of "left" and
"right" will be delivered to that function,
in terms of that binary.
-
+
Hence an expression like::
-
+
and_(
(a + b) == q + func.sum(e + f),
j == r
)
-
+
would have the traversal::
-
+
a <eq> q
a <eq> e
a <eq> f
@@ -96,7 +98,7 @@ def visit_binary_product(fn, expr):
That is, every combination of "left" and
"right" that doesn't further contain
a binary comparison is passed as pairs.
-
+
"""
stack = []
def visit(element):
@@ -121,8 +123,8 @@ def visit_binary_product(fn, expr):
yield e
list(visit(expr))
-def find_tables(clause, check_columns=False,
- include_aliases=False, include_joins=False,
+def find_tables(clause, check_columns=False,
+ include_aliases=False, include_joins=False,
include_selects=False, include_crud=False):
"""locate Table objects within the given expression."""
@@ -171,7 +173,7 @@ def unwrap_order_by(clause):
(
not isinstance(t, expression._UnaryExpression) or \
not operators.is_ordering_modifier(t.modifier)
- ):
+ ):
cols.add(t)
else:
for c in t.get_children():
@@ -226,7 +228,7 @@ def _quote_ddl_expr(element):
class _repr_params(object):
"""A string view of bound parameters, truncating
display to the given number of 'multi' parameter sets.
-
+
"""
def __init__(self, params, batches):
self.params = params
@@ -246,7 +248,7 @@ class _repr_params(object):
def expression_as_ddl(clause):
- """Given a SQL expression, convert for usage in DDL, such as
+ """Given a SQL expression, convert for usage in DDL, such as
CREATE INDEX and CHECK CONSTRAINT.
Converts bind params into quoted literals, column identifiers
@@ -285,7 +287,7 @@ def adapt_criterion_to_null(crit, nulls):
return visitors.cloned_traverse(crit, {}, {'binary':visit_binary})
-def join_condition(a, b, ignore_nonexistent_tables=False,
+def join_condition(a, b, ignore_nonexistent_tables=False,
a_subset=None,
consider_as_foreign_keys=None):
"""create a join condition between two tables or selectables.
@@ -321,7 +323,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False,
if left is None:
continue
for fk in sorted(
- b.foreign_keys,
+ b.foreign_keys,
key=lambda fk:fk.parent._creation_order):
if consider_as_foreign_keys is not None and \
fk.parent not in consider_as_foreign_keys:
@@ -339,7 +341,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False,
constraints.add(fk.constraint)
if left is not b:
for fk in sorted(
- left.foreign_keys,
+ left.foreign_keys,
key=lambda fk:fk.parent._creation_order):
if consider_as_foreign_keys is not None and \
fk.parent not in consider_as_foreign_keys:
@@ -385,12 +387,12 @@ def join_condition(a, b, ignore_nonexistent_tables=False,
class Annotated(object):
"""clones a ClauseElement and applies an 'annotations' dictionary.
- Unlike regular clones, this clone also mimics __hash__() and
+ Unlike regular clones, this clone also mimics __hash__() and
__cmp__() of the original element so that it takes its place
in hashed collections.
A reference to the original element is maintained, for the important
- reason of keeping its hash value current. When GC'ed, the
+ reason of keeping its hash value current. When GC'ed, the
hash value may be reused, causing conflicts.
"""
@@ -406,13 +408,13 @@ class Annotated(object):
try:
cls = annotated_classes[element.__class__]
except KeyError:
- cls = annotated_classes[element.__class__] = type.__new__(type,
- "Annotated%s" % element.__class__.__name__,
+ cls = annotated_classes[element.__class__] = type.__new__(type,
+ "Annotated%s" % element.__class__.__name__,
(Annotated, element.__class__), {})
return object.__new__(cls)
def __init__(self, element, values):
- # force FromClause to generate their internal
+ # force FromClause to generate their internal
# collections into __dict__
if isinstance(element, expression.FromClause):
element.c
@@ -481,7 +483,7 @@ for cls in expression.__dict__.values() + [schema.Column, schema.Table]:
exec "annotated_classes[cls] = Annotated%s" % (cls.__name__)
def _deep_annotate(element, annotations, exclude=None):
- """Deep copy the given ClauseElement, annotating each element
+ """Deep copy the given ClauseElement, annotating each element
with the given annotations dictionary.
Elements within the exclude collection will be cloned but not annotated.
@@ -528,17 +530,17 @@ def _deep_deannotate(element, values=None):
element = clone(element)
return element
-def _shallow_annotate(element, annotations):
- """Annotate the given ClauseElement and copy its internals so that
- internal objects refer to the new annotated object.
+def _shallow_annotate(element, annotations):
+ """Annotate the given ClauseElement and copy its internals so that
+ internal objects refer to the new annotated object.
- Basically used to apply a "dont traverse" annotation to a
- selectable, without digging throughout the whole
- structure wasting time.
- """
- element = element._annotate(annotations)
- element._copy_internals()
- return element
+ Basically used to apply a "dont traverse" annotation to a
+ selectable, without digging throughout the whole
+ structure wasting time.
+ """
+ element = element._annotate(annotations)
+ element._copy_internals()
+ return element
def splice_joins(left, right, stop_on=None):
if left is None:
@@ -626,7 +628,7 @@ def reduce_columns(columns, *clauses, **kw):
return expression.ColumnSet(columns.difference(omit))
-def criterion_as_pairs(expression, consider_as_foreign_keys=None,
+def criterion_as_pairs(expression, consider_as_foreign_keys=None,
consider_as_referenced_keys=None, any_operator=False):
"""traverse an expression and locate binary criterion pairs."""
@@ -648,20 +650,20 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
if consider_as_foreign_keys:
if binary.left in consider_as_foreign_keys and \
- (col_is(binary.right, binary.left) or
+ (col_is(binary.right, binary.left) or
binary.right not in consider_as_foreign_keys):
pairs.append((binary.right, binary.left))
elif binary.right in consider_as_foreign_keys and \
- (col_is(binary.left, binary.right) or
+ (col_is(binary.left, binary.right) or
binary.left not in consider_as_foreign_keys):
pairs.append((binary.left, binary.right))
elif consider_as_referenced_keys:
if binary.left in consider_as_referenced_keys and \
- (col_is(binary.right, binary.left) or
+ (col_is(binary.right, binary.left) or
binary.right not in consider_as_referenced_keys):
pairs.append((binary.left, binary.right))
elif binary.right in consider_as_referenced_keys and \
- (col_is(binary.left, binary.right) or
+ (col_is(binary.left, binary.right) or
binary.left not in consider_as_referenced_keys):
pairs.append((binary.right, binary.left))
else:
@@ -678,17 +680,17 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
def folded_equivalents(join, equivs=None):
"""Return a list of uniquely named columns.
- The column list of the given Join will be narrowed
+ The column list of the given Join will be narrowed
down to a list of all equivalently-named,
equated columns folded into one column, where 'equated' means they are
equated to each other in the ON clause of this join.
This function is used by Join.select(fold_equivalents=True).
- Deprecated. This function is used for a certain kind of
+ Deprecated. This function is used for a certain kind of
"polymorphic_union" which is designed to achieve joined
table inheritance where the base table has no "discriminator"
- column; [ticket:1131] will provide a better way to
+ column; [ticket:1131] will provide a better way to
achieve this.
"""
@@ -773,9 +775,9 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
s.c.col1 == table2.c.col1
"""
- def __init__(self, selectable, equivalents=None,
- include=None, exclude=None,
- include_fn=None, exclude_fn=None,
+ def __init__(self, selectable, equivalents=None,
+ include=None, exclude=None,
+ include_fn=None, exclude_fn=None,
adapt_on_names=False):
self.__traverse_options__ = {'stop_on':[selectable]}
self.selectable = selectable
@@ -794,12 +796,12 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET):
newcol = self.selectable.corresponding_column(
- col,
+ col,
require_embedded=require_embedded)
if newcol is None and col in self.equivalents and col not in _seen:
for equiv in self.equivalents[col]:
- newcol = self._corresponding_column(equiv,
- require_embedded=require_embedded,
+ newcol = self._corresponding_column(equiv,
+ require_embedded=require_embedded,
_seen=_seen.union([col]))
if newcol is not None:
return newcol
@@ -823,14 +825,14 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
class ColumnAdapter(ClauseAdapter):
"""Extends ClauseAdapter with extra utility functions.
- Provides the ability to "wrap" this ClauseAdapter
+ Provides the ability to "wrap" this ClauseAdapter
around another, a columns dictionary which returns
- adapted elements given an original, and an
+ adapted elements given an original, and an
adapted_row() factory.
"""
- def __init__(self, selectable, equivalents=None,
- chain_to=None, include=None,
+ def __init__(self, selectable, equivalents=None,
+ chain_to=None, include=None,
exclude=None, adapt_required=False):
ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
if chain_to:
@@ -866,7 +868,7 @@ class ColumnAdapter(ClauseAdapter):
c = c.label(None)
# adapt_required indicates that if we got the same column
- # back which we put in (i.e. it passed through),
+ # back which we put in (i.e. it passed through),
# it's not correct. this is used by eagerloading which
# knows that all columns and expressions need to be adapted
# to a result row, and a "passthrough" is definitely targeting