summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/cyextension/resultproxy.pyx
blob: daf5cc9400905ef6d036d98228dd3c0dd9a88719 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# TODO: this is mostly copied over from the Python implementation;
# further improvements are likely possible.
import operator

# Position inside a keymap record from which BaseRow reads the index into
# the row's data tuple (see ``rec[MD_INDEX]`` in
# ``BaseRow._get_by_key_impl_mapping``).
cdef int MD_INDEX = 0  # integer index in cursor.description

# Key-style flags; a row stores one of these as ``_key_style``.
# NOTE(review): KEY_INTEGER_ONLY is not referenced by the code in this
# module; presumably it is consumed by importers -- confirm.
KEY_INTEGER_ONLY = 0
# When a row's ``_key_style`` equals KEY_OBJECTS_ONLY,
# ``BaseRow._get_by_key_impl_mapping`` raises KeyError for integer keys.
KEY_OBJECTS_ONLY = 1

# Placeholder that ``BaseRow._filter_on_values`` rebinds to
# ``sqlalchemy.engine.row.Row`` the first time it is called.
sqlalchemy_engine_row = None

cdef class BaseRow:
    """Cython base class holding the state of a single result row.

    A row keeps a reference to its parent, the tuple of (optionally
    processed) column values, the mapping used to resolve keys to records,
    and the key style that controls whether integer keys are accepted by
    the mapping-style lookup.
    """

    cdef readonly object _parent
    cdef readonly tuple _data
    cdef readonly dict _keymap
    cdef readonly int _key_style

    def __init__(self, object parent, object processors, dict keymap, int key_style, object data):
        """Populate the row from raw ``data``.

        If ``processors`` is truthy it is paired with ``data`` via ``zip``;
        every truthy processor is called on its value, while values paired
        with a falsy processor are stored unchanged.  Otherwise ``data`` is
        simply converted to a tuple.
        """

        self._parent = parent

        if not processors:
            self._data = tuple(data)
        else:
            # a list comprehension (not a generator) is fed to tuple(), as
            # in the original implementation
            self._data = tuple(
                [
                    processor(raw) if processor else raw
                    for processor, raw in zip(processors, data)
                ]
            )

        self._keymap = keymap

        self._key_style = key_style

    def __reduce__(self):
        # pickle support: rebuild through rowproxy_reconstructor using the
        # concrete class and the reduced state dictionary
        return (
            rowproxy_reconstructor,
            (self.__class__, self.__getstate__()),
        )

    def __getstate__(self):
        # the keymap is deliberately left out; __setstate__ recovers it
        # from the parent
        return dict(
            _parent=self._parent,
            _data=self._data,
            _key_style=self._key_style,
        )

    def __setstate__(self, dict state):
        parent = state["_parent"]
        self._parent = parent
        self._data = state["_data"]
        # the keymap is not part of the state; take it from the parent
        self._keymap = parent._keymap
        self._key_style = state["_key_style"]

    def _filter_on_values(self, filters):
        """Return a new ``Row`` over the same data, using ``filters`` as the
        processors argument.
        """
        global sqlalchemy_engine_row
        if sqlalchemy_engine_row is None:
            # imported on first use and cached in the module global
            from sqlalchemy.engine.row import Row as sqlalchemy_engine_row

        row_factory = sqlalchemy_engine_row
        return row_factory(
            self._parent,
            filters,
            self._keymap,
            self._key_style,
            self._data,
        )

    def _values_impl(self):
        """Return the values produced by iterating this row, as a list."""
        return list(self)

    def __iter__(self):
        # iteration is delegated to the underlying data tuple
        return iter(self._data)

    def __len__(self):
        # number of values held in the data tuple
        return len(self._data)

    def __hash__(self):
        # rows hash like their data tuple
        return hash(self._data)

    def _get_by_int_impl(self, key):
        """Index directly into the data tuple with ``key``."""
        return self._data[key]

    cpdef _get_by_key_impl(self, key):
        """Return the value(s) for an integer or slice ``key``.

        Any other key type is handed to the parent's ``_raise_for_nonint``.
        """
        # keep two isinstance since it's noticeably faster in the int case
        if not (isinstance(key, int) or isinstance(key, slice)):
            self._parent._raise_for_nonint(key)
            return None

        return self._data[key]

    def __getitem__(self, key):
        # subscripting accepts integers and slices only
        return self._get_by_key_impl(key)

    cpdef _get_by_key_impl_mapping(self, key):
        """Resolve ``key`` through the keymap and return the matching value.

        Keys missing from the keymap are passed to the parent's
        ``_key_fallback``.  A record whose index entry is ``None`` is
        reported through the parent's ``_raise_for_ambiguous_column_name``.
        Integer keys raise ``KeyError`` when the row's key style is
        ``KEY_OBJECTS_ONLY``.
        """
        try:
            record = self._keymap[key]
        except KeyError as lookup_error:
            record = self._parent._key_fallback(key, lookup_error)

        data_index = record[MD_INDEX]
        if data_index is None:
            self._parent._raise_for_ambiguous_column_name(record)
        else:
            if self._key_style == KEY_OBJECTS_ONLY and isinstance(key, int):
                raise KeyError(key)

        return self._data[data_index]

    def __getattr__(self, name):
        # attribute access is a mapping lookup by name; a failed lookup is
        # surfaced as AttributeError chained to the original KeyError
        try:
            return self._get_by_key_impl_mapping(name)
        except KeyError as key_error:
            raise AttributeError(key_error.args[0]) from key_error


def rowproxy_reconstructor(cls, state):
    """Rebuild an instance of ``cls`` from ``state``.

    The instance is allocated with ``cls.__new__`` (``__init__`` is not
    called) and then populated through its ``__setstate__`` method.
    """
    instance = cls.__new__(cls)
    instance.__setstate__(state)
    return instance


def tuplegetter(*indexes):
    """Return a callable that pulls ``indexes`` out of a row as a tuple.

    ``operator.itemgetter`` already returns a tuple when given several
    indexes, but returns the bare item when given exactly one; that case
    is wrapped so the result is a one-element tuple.
    """
    getter = operator.itemgetter(*indexes)

    if len(indexes) <= 1:
        # single index: box the bare value itemgetter returns
        return lambda row: (getter(row),)

    return getter