summaryrefslogtreecommitdiff
path: root/test/orm/test_descriptor.py
blob: cea518e7c608f9d09f8d8ba03d37fc1504852178 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from sqlalchemy import Column
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy.orm import aliased
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import descriptor_props
from sqlalchemy.orm.interfaces import PropComparator
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.sql import column
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.util import partial


class MockDescriptor(descriptor_props.DescriptorProperty):
    """A hand-assembled DescriptorProperty for instrumentation tests.

    Skips the usual mapper configuration step. The owning mapper is read
    straight from ``cls.__mapper__``, and the comparator factory is set up
    here rather than by the ORM.

    :param cls: a mapped class; its ``__mapper__`` becomes ``self.parent``.
    :param key: attribute name the property is instrumented under.
    :param descriptor: optional underlying descriptor, stored as-is.
    :param doc: optional docstring, stored as-is.
    :param comparator_factory: optional callable that receives this
        property as its first positional argument followed by the mapper.
        When omitted, the factory yields ``None`` for any mapper.
    """

    def __init__(
        self, cls, key, descriptor=None, doc=None, comparator_factory=None
    ):
        owning_mapper = cls.__mapper__
        self.parent = owning_mapper
        self.key, self.doc = key, doc
        self.descriptor = descriptor
        # Pre-bind this property as the factory's first argument so the
        # factory can see both the property and the mapper it is handed.
        self._comparator_factory = (
            partial(comparator_factory, self)
            if comparator_factory
            else (lambda mapper: None)
        )


class DescriptorInstrumentationTest(fixtures.ORMTest):
    """Tests for how a DescriptorProperty instruments its mapped class."""

    def _fixture(self):
        """Return a freshly declared ``Foo`` class mapped to table "foo".

        ``Foo`` has only an integer primary key column ``id``. A new
        declarative base is built on every call, so each test gets its own
        independent class.
        """
        Base = declarative_base()

        class Foo(Base):
            __tablename__ = "foo"
            id = Column(Integer, primary_key=True)

        return Foo

    def test_fixture(self):
        """Instrumenting the mock descriptor exposes a class attribute."""
        Foo = self._fixture()

        d = MockDescriptor(Foo, "foo")
        d.instrument_class(Foo.__mapper__)

        assert Foo.foo

    def test_property_wrapped_classlevel(self):
        """A pre-existing plain ``property`` is wrapped, not left in place.

        Instance-level access still goes to the original property's getter,
        while the class-level attribute is a different object.
        """
        Foo = self._fixture()
        prop = property(lambda self: None)
        Foo.foo = prop

        d = MockDescriptor(Foo, "foo")
        d.instrument_class(Foo.__mapper__)

        assert Foo().foo is None
        assert Foo.foo is not prop

    def test_property_subclass_wrapped_classlevel(self):
        """Class-level access reaches members of a wrapped property subclass."""
        Foo = self._fixture()

        class myprop(property):
            attr = "bar"

            def method1(self):
                return "method1"

        prop = myprop(lambda self: None)
        Foo.foo = prop

        d = MockDescriptor(Foo, "foo")
        d.instrument_class(Foo.__mapper__)

        assert Foo().foo is None
        assert Foo.foo is not prop
        # Attribute and method lookups on the class-level wrapper fall
        # through to the original ``myprop`` instance.
        eq_(Foo.foo.attr, "bar")
        eq_(Foo.foo.method1(), "method1")

    def test_comparator(self):
        """The custom comparator is reachable through the class attribute.

        Covers plain methods, class attributes, ``__getitem__`` and an
        overridden ``__eq__`` that builds a SQL expression.
        """

        class Comparator(PropComparator):
            # __eq__ is overridden below to build SQL, so the comparator
            # is marked unhashable explicitly.
            __hash__ = None

            attr = "bar"

            def method1(self):
                return "method1"

            def method2(self, other):
                return "method2"

            def __getitem__(self, key):
                return "value"

            def __eq__(self, other):
                return column("foo") == func.upper(other)

        Foo = self._fixture()
        d = MockDescriptor(Foo, "foo", comparator_factory=Comparator)
        d.instrument_class(Foo.__mapper__)
        eq_(Foo.foo.method1(), "method1")
        eq_(Foo.foo.method2("x"), "method2")
        eq_(Foo.foo.attr, "bar")
        eq_(Foo.foo["bar"], "value")
        eq_(str(Foo.foo == "bar"), "foo = upper(:upper_1)")

    def test_aliased_comparator(self):
        """A column-based comparator is adapted to ``aliased()`` entities.

        The comparator targets the ``_name`` column. Against ``Foo`` itself
        it renders ``foo.name``; against an alias it renders ``foo_1.name``.
        """

        class Comparator(ColumnProperty.Comparator):
            __hash__ = None

            def __eq__(self, other):
                return func.foobar(self.__clause_element__()) == func.foobar(
                    other
                )

        Foo = self._fixture()
        # Assigning a Column to the declarative class after creation maps
        # it; it is looked up below as the "_name" property.
        Foo._name = Column("name", String)

        # The first argument is the MockDescriptor, bound via ``partial``.
        # It is named ``descriptor`` so it does not shadow the test's
        # ``self``.
        def comparator_factory(descriptor, mapper):
            prop = mapper._props["_name"]
            return Comparator(prop, mapper)

        d = MockDescriptor(Foo, "foo", comparator_factory=comparator_factory)
        d.instrument_class(Foo.__mapper__)

        eq_(str(Foo.foo == "ed"), "foobar(foo.name) = foobar(:foobar_1)")
        eq_(
            str(aliased(Foo).foo == "ed"),
            "foobar(foo_1.name) = foobar(:foobar_1)",
        )