summaryrefslogtreecommitdiff
path: root/examples/custom_attributes/custom_management.py
blob: aa9ea7a68998eb1ccf95f0917ea1ba117c806696 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Illustrates customized class instrumentation, using
the :mod:`sqlalchemy.ext.instrumentation` extension package.

In this example, mapped classes are modified to
store their state in a dictionary attached to an attribute
named "_goofy_dict", instead of using __dict__.
This example illustrates how to replace SQLAlchemy's class
descriptors with a user-defined system.


"""
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import Text
from sqlalchemy.ext.instrumentation import InstrumentationManager
from sqlalchemy.orm import registry as _reg
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import del_attribute
from sqlalchemy.orm.attributes import get_attribute
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.instrumentation import is_instrumented


# Module-level mapper registry; the demo script below calls
# ``registry.map_imperatively`` on it to map the example classes.
registry = _reg()


class MyClassState(InstrumentationManager):
    """Instrumentation manager that keeps instance data in ``_goofy_dict``.

    Each hook works against a dictionary stored under the ``"_goofy_dict"``
    key of the instance's real ``__dict__``; the instrumentation state
    object is kept inside that dictionary under the ``"state"`` key.
    """

    def get_instance_dict(self, class_, instance):
        """Return the ``_goofy_dict`` attribute of ``instance``."""
        goofy = instance._goofy_dict
        return goofy

    def initialize_instance_dict(self, class_, instance):
        """Store a new empty ``_goofy_dict`` in ``instance.__dict__``."""
        instance.__dict__.update(_goofy_dict={})

    def install_state(self, class_, instance, state):
        """Record ``state`` under the ``"state"`` key of ``_goofy_dict``."""
        goofy = instance.__dict__["_goofy_dict"]
        goofy["state"] = state

    def state_getter(self, class_):
        """Return a callable that fetches the stored state of an instance."""

        def _state_of(instance):
            # Read straight from ``__dict__`` so a missing entry surfaces
            # as ``KeyError`` rather than going through attribute lookup.
            goofy = instance.__dict__["_goofy_dict"]
            return goofy["state"]

        return _state_of


class MyClass:
    """Base class whose attribute access is routed through ``_goofy_dict``.

    Attributes reported as instrumented by ``is_instrumented`` are delegated
    to ``get_attribute`` / ``set_attribute`` / ``del_attribute``.  Every other
    attribute is read from, written to, or deleted from the per-instance
    ``_goofy_dict`` dictionary.
    """

    __sa_instrumentation_manager__ = MyClassState

    def __init__(self, **kwargs):
        """Assign each keyword argument as an attribute of the instance."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getattr__(self, key):
        """Return the value of ``key`` when normal lookup has failed.

        Raises:
            AttributeError: if ``key`` is not instrumented and is not
                present in ``_goofy_dict``, or if ``_goofy_dict`` itself
                has not been set on the instance yet.
        """
        # ``__getattr__`` only runs when regular lookup fails, so reaching
        # this point for "_goofy_dict" means the dictionary is missing.
        # Fail fast: the ``self._goofy_dict`` access below would otherwise
        # re-enter this method until RecursionError.
        if key == "_goofy_dict":
            raise AttributeError(key)
        if is_instrumented(self, key):
            return get_attribute(self, key)
        try:
            return self._goofy_dict[key]
        except KeyError:
            # The KeyError is an implementation detail of the storage.
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        """Set ``key`` via instrumentation, or store it in ``_goofy_dict``."""
        if is_instrumented(self, key):
            set_attribute(self, key, value)
        else:
            self._goofy_dict[key] = value

    def __delattr__(self, key):
        """Delete ``key`` via instrumentation, or from ``_goofy_dict``.

        Raises:
            KeyError: if ``key`` is not instrumented and is not present in
                ``_goofy_dict``.
        """
        if is_instrumented(self, key):
            del_attribute(self, key)
        else:
            del self._goofy_dict[key]


if __name__ == "__main__":
    # Demonstration: map two classes that use the custom instrumentation,
    # persist them, reload them, and modify a relationship collection.
    engine = create_engine("sqlite://")
    meta = MetaData()

    table1 = Table(
        "table1",
        meta,
        Column("id", Integer, primary_key=True),
        Column("name", Text),
    )
    table2 = Table(
        "table2",
        meta,
        Column("id", Integer, primary_key=True),
        Column("name", Text),
        Column("t1id", Integer, ForeignKey("table1.id")),
    )
    meta.create_all(engine)

    class A(MyClass):
        pass

    class B(MyClass):
        pass

    registry.map_imperatively(A, table1, properties={"bs": relationship(B)})

    registry.map_imperatively(B, table2)

    a1 = A(name="a1", bs=[B(name="b1"), B(name="b2")])

    # Attribute access works on the pending, not-yet-persisted objects.
    assert a1.name == "a1"
    assert a1.bs[0].name == "b1"

    # The context manager closes the session when the block exits.
    with Session(engine) as sess:
        sess.add(a1)

        sess.commit()

        # Session.get() replaces the legacy Query.get() primary-key lookup.
        a1 = sess.get(A, a1.id)

        assert a1.name == "a1"
        assert a1.bs[0].name == "b1"

        a1.bs.remove(a1.bs[0])

        sess.commit()

        a1 = sess.get(A, a1.id)
        assert len(a1.bs) == 1