summaryrefslogtreecommitdiff
path: root/examples/adjacency_list/adjacency_list.py
blob: 38503f9f333f705e4bc81e6eb600d9a3bcf36088 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.orm.collections import attribute_keyed_dict


class Base(DeclarativeBase):
    """Declarative base class for the mapped classes in this example.

    Subclassing ``sqlalchemy.orm.DeclarativeBase`` is the SQLAlchemy 2.0
    way to create a declarative base.  The older
    ``sqlalchemy.ext.declarative.declarative_base()`` function is
    deprecated in 2.0 and emits a warning when called.
    """


class TreeNode(Base):
    """A node of a tree stored as an adjacency list in the ``tree`` table.

    Every row references its parent row through ``parent_id``.  The
    ``children`` relationship exposes a node's direct children as a dict
    keyed by each child's ``name``.  The ``parent`` backref is the reverse,
    many-to-one side.
    """

    __tablename__ = "tree"

    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey(id))
    name = Column(String(50), nullable=False)

    # One-to-many from a node to its direct children.
    children = relationship(
        "TreeNode",
        # Cascade everything, including deletes, down to the children.
        # A child removed from this collection is deleted as an orphan.
        cascade="all, delete-orphan",
        # The reverse side ("parent") is many-to-one.  On a self-referential
        # table the join condition is ambiguous, so remote_side marks the
        # parent's ``id`` column as the 'remote' end.
        backref=backref("parent", remote_side=id),
        # Expose children as a dictionary keyed on their "name" attribute.
        collection_class=attribute_keyed_dict("name"),
    )

    def __init__(self, name, parent=None):
        """Create a node called *name*, optionally attached to *parent*."""
        # The name is assigned before the parent on purpose.  Setting
        # ``parent`` goes through the backref and files this node into
        # ``parent.children``, which is keyed by ``name``.
        self.name = name
        self.parent = parent

    def __repr__(self):
        """Show the node's name, primary key and parent key."""
        shown = (self.name, self.id, self.parent_id)
        return "TreeNode(name=%r, id=%r, parent_id=%r)" % shown

    def dump(self, _indent=0):
        """Render this node and its whole subtree as indented text.

        Each node contributes one line: its repr, prefixed by three spaces
        per level of depth and terminated by a newline.  Children follow
        their parent, in the order ``children.values()`` yields them.
        """
        pieces = ["   " * _indent + repr(self) + "\n"]
        for child in self.children.values():
            pieces.append(child.dump(_indent + 1))
        return "".join(pieces)


if __name__ == "__main__":
    # In-memory SQLite database.  echo=True logs every SQL statement, so
    # the effect of each step below is visible.
    engine = create_engine("sqlite://", echo=True)

    def msg(text, *args):
        """Print ``text % args`` between two dashed rules.

        The rules are as wide as the first line of the formatted text.  A
        multi-line message (such as a tree dump) is therefore framed by the
        width of its heading.
        """
        text = text % args
        rule = "-" * len(text.split("\n")[0])
        print("\n\n\n" + rule)
        print(text)
        print(rule)

    msg("Creating Tree Table:")

    Base.metadata.create_all(engine)

    # The context manager closes the session (releasing its connection)
    # when the demo finishes, even if a step raises.
    with Session(engine) as session:
        node = TreeNode("rootnode")
        TreeNode("node1", parent=node)
        TreeNode("node3", parent=node)

        # Children can also be attached from the parent's side, by
        # assigning into the name-keyed ``children`` dict.
        node2 = TreeNode("node2")
        TreeNode("subnode1", parent=node2)
        node.children["node2"] = node2
        TreeNode("subnode2", parent=node.children["node2"])

        msg("Created new tree structure:\n%s", node.dump())

        msg("flush + commit:")

        # Adding only the root is enough.  The "all" cascade on
        # ``children`` brings every descendant into the session.
        session.add(node)
        session.commit()

        msg("Tree After Save:\n %s", node.dump())

        TreeNode("node4", parent=node)
        TreeNode("subnode3", parent=node.children["node4"])
        TreeNode("subnode4", parent=node.children["node4"])
        TreeNode(
            "subsubnode1", parent=node.children["node4"].children["subnode3"]
        )

        # remove node1 from the parent, which will trigger a delete
        # via the delete-orphan cascade.
        del node.children["node1"]

        msg("Removed node1.  flush + commit:")
        session.commit()

        msg("Tree after save:\n %s", node.dump())

        msg(
            "Emptying out the session entirely, selecting tree on root, using "
            "eager loading to join four levels deep."
        )
        session.expunge_all()
        node = (
            session.query(TreeNode)
            .options(
                # Loader options must be given class-bound attributes.
                # SQLAlchemy 2.0 rejects string names such as "children"
                # with an ArgumentError.
                joinedload(TreeNode.children)
                .joinedload(TreeNode.children)
                .joinedload(TreeNode.children)
                .joinedload(TreeNode.children)
            )
            .filter(TreeNode.name == "rootnode")
            .first()
        )

        msg("Full Tree:\n%s", node.dump())

        msg("Marking root node as deleted, flush + commit:")

        # The delete cascades through ``children`` to the whole tree.
        session.delete(node)
        session.commit()