summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/visitors.py
blob: 98e4de6c3341ebcb4aff969598241702dd28940d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class ClauseVisitor(object):
    """A class that knows how to traverse and visit
    ``ClauseElements``.

    Calls visit_XXX() methods dynamically generated for each particular
    ``ClauseElement`` subclass encountered.  Traversal of a
    hierarchy of ``ClauseElements`` is achieved via the
    ``traverse()`` method, which is passed the lead
    ``ClauseElement``.

    By default, ``ClauseVisitor`` traverses all elements
    fully.  Options can be specified at the class level via the
    ``__traverse_options__`` dictionary which will be passed
    to the ``get_children()`` method of each ``ClauseElement``;
    these options can indicate modifications to the set of
    elements returned, such as to not return column collections
    (column_collections=False) or to return Schema-level items
    (schema_visitor=True).

    ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
    operation, which will produce a copy of a given ``ClauseElement``
    structure while at the same time allowing ``ClauseVisitor`` subclasses
    to modify the new structure in-place.

    """
    __traverse_options__ = {}

    def traverse_single(self, obj, **kwargs):
        """Visit ``obj`` alone, without descending into its children.

        Looks up ``visit_<obj.__visit_name__>`` on this visitor and, when
        present, calls it with ``obj`` and any keyword arguments, returning
        its result.  Returns ``None`` when no such method exists.  Chained
        visitors are not consulted.
        """
        meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
        if meth:
            return meth(obj, **kwargs)

    def iterate(self, obj, stop_on=None):
        """Yield ``obj`` and its descendants without calling visit methods.

        Elements are yielded in the order they are popped from an explicit
        stack, so the last child returned by ``get_children()`` is yielded
        (and expanded) before its earlier siblings.  An element contained
        in ``stop_on`` is neither yielded nor expanded.
        """
        stack = [obj]
        while stack:
            t = stack.pop()
            if stop_on is None or t not in stop_on:
                yield t
                stack.extend(t.get_children(**self.__traverse_options__))

    def traverse(self, obj, stop_on=None, clone=False):
        """Traverse the structure rooted at ``obj`` and visit every element.

        The hierarchy is first collected using an explicit stack; the
        collected elements are then visited in the reverse of the order in
        which they were collected, so ``obj`` itself is visited last.  Each
        element is passed to the matching ``visit_<__visit_name__>`` method
        of this visitor and then of every visitor attached via ``chain()``.

        ``stop_on`` is an optional collection of elements which are neither
        visited nor descended into.

        When ``clone`` is true, ``obj._clone()`` is traversed instead of
        ``obj`` and ``_copy_internals()`` is called on each collected
        element before its children are fetched.

        Returns the traversed element: ``obj``, or its clone when ``clone``
        is true.
        """
        if clone:
            obj = obj._clone()

        stack = [obj]
        # Collected in pop order and walked in reverse below; this yields
        # the same visit order as prepending each element, without the
        # O(n) cost of inserting at the head of a list every time.
        collected = []
        while stack:
            t = stack.pop()
            if stop_on is None or t not in stop_on:
                collected.append(t)
                if clone:
                    # must run before get_children() so that the children
                    # fetched below are the ones held after copying
                    t._copy_internals()
                stack.extend(t.get_children(**self.__traverse_options__))

        for target in reversed(collected):
            meth_name = "visit_%s" % target.__visit_name__
            # walk this visitor followed by each chained visitor in turn
            v = self
            while v is not None:
                meth = getattr(v, meth_name, None)
                if meth:
                    meth(target)
                v = getattr(v, '_next', None)
        return obj

    def chain(self, visitor):
        """'chain' an additional ClauseVisitor onto this ClauseVisitor.

        The chained visitor will receive all visit events after this one.
        It is attached at the end of any existing chain, and this visitor
        is returned so that calls may be strung together.
        """
        tail = self
        while getattr(tail, '_next', None) is not None:
            tail = tail._next
        tail._next = visitor
        return self

class NoColumnVisitor(ClauseVisitor):
    """A ``ClauseVisitor`` which skips exported Column collections.

    The 'columns' / 'c' attribute of Table, Alias, Select and
    CompoundSelect objects is left out of the traversal.

    Most traversals have no need for those columns, and DefaultCompiler
    traverses them explicitly, so leaving them out here greatly cuts
    down on method call overhead.
    """

    # handed to get_children() of every traversed element
    __traverse_options__ = dict(column_collections=False)