diff options
author | Federico Caselli <cfederico87@gmail.com> | 2021-01-01 16:09:01 +0100 |
---|---|---|
committer | Federico Caselli <cfederico87@gmail.com> | 2021-12-17 21:29:05 +0100 |
commit | 76fa211620de167b76846f0e5db5b64b8756ad48 (patch) | |
tree | c435dbf6585b3758dc78ee82bf114e162a25d0e1 /lib/sqlalchemy/cyextension | |
parent | 3543fcc9c9601e81560d055ceadaea05c75815c0 (diff) | |
download | sqlalchemy-workflow_test_cython.tar.gz |
Replace c extension with cython versions. (branch: workflow_test_cython)
Re-implement the C versions of immutabledict / processors / resultproxy / utils in Cython.
Performance is in general on par with or better than the C version.
Added a collections module that has Cython versions of OrderedSet and IdentitySet.
Added a new test/perf file to compare the implementations.
Run ``python test/perf/compiled_extensions.py all`` to execute the comparison test.
See results here: https://docs.google.com/document/d/1nOcDGojHRtXEkuy4vNXcW_XOJd9gqKhSeALGG3kYr6A/edit?usp=sharing
Fixes: #7256
Change-Id: I2930ef1894b5048210384728118e586e813f6a76
Signed-off-by: Federico Caselli <cfederico87@gmail.com>
Diffstat (limited to 'lib/sqlalchemy/cyextension')
-rw-r--r-- | lib/sqlalchemy/cyextension/.gitignore | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/cyextension/__init__.py | 0 | ||||
-rw-r--r-- | lib/sqlalchemy/cyextension/collections.pyx | 393 | ||||
-rw-r--r-- | lib/sqlalchemy/cyextension/immutabledict.pxd | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/cyextension/immutabledict.pyx | 100 | ||||
-rw-r--r-- | lib/sqlalchemy/cyextension/processors.pyx | 91 | ||||
-rw-r--r-- | lib/sqlalchemy/cyextension/resultproxy.pyx | 130 | ||||
-rw-r--r-- | lib/sqlalchemy/cyextension/util.pyx | 43 |
8 files changed, 764 insertions, 0 deletions
diff --git a/lib/sqlalchemy/cyextension/.gitignore b/lib/sqlalchemy/cyextension/.gitignore new file mode 100644 index 000000000..dfc107eaf --- /dev/null +++ b/lib/sqlalchemy/cyextension/.gitignore @@ -0,0 +1,5 @@ +# cython compiled files +*.c +*.o +# cython annotated output +*.html
\ No newline at end of file diff --git a/lib/sqlalchemy/cyextension/__init__.py b/lib/sqlalchemy/cyextension/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/lib/sqlalchemy/cyextension/__init__.py diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx new file mode 100644 index 000000000..e695d4c62 --- /dev/null +++ b/lib/sqlalchemy/cyextension/collections.pyx @@ -0,0 +1,393 @@ +from cpython.dict cimport PyDict_Merge, PyDict_Update +from cpython.long cimport PyLong_FromLong +from cpython.set cimport PySet_Add + +from itertools import filterfalse + +cdef bint add_not_present(set seen, object item, hashfunc): + hash_value = hashfunc(item) + if hash_value not in seen: + PySet_Add(seen, hash_value) + return True + else: + return False + +cdef list cunique_list(seq, hashfunc=None): + cdef set seen = set() + if not hashfunc: + return [x for x in seq if x not in seen and not PySet_Add(seen, x)] + else: + return [x for x in seq if add_not_present(seen, x, hashfunc)] + +def unique_list(seq, hashfunc=None): + return cunique_list(seq, hashfunc) + +cdef class OrderedSet(set): + + cdef list _list + + def __init__(self, d=None): + set.__init__(self) + if d is not None: + self._list = cunique_list(d) + set.update(self, self._list) + else: + self._list = [] + + cdef OrderedSet _copy(self): + cdef OrderedSet cp = OrderedSet.__new__(OrderedSet) + cp._list = list(self._list) + set.update(cp, cp._list) + return cp + + cdef OrderedSet _from_list(self, list new_list): + cdef OrderedSet new = OrderedSet.__new__(OrderedSet) + new._list = new_list + set.update(new, new_list) + return new + + def add(self, element): + if element not in self: + self._list.append(element) + PySet_Add(self, element) + + def remove(self, element): + # set.remove will raise if element is not in self + set.remove(self, element) + self._list.remove(element) + + def insert(self, Py_ssize_t pos, element): + if element not in self: + 
self._list.insert(pos, element) + PySet_Add(self, element) + + def discard(self, element): + if element in self: + set.remove(self, element) + self._list.remove(element) + + def clear(self): + set.clear(self) + self._list = [] + + def __getitem__(self, key): + return self._list[key] + + def __iter__(self): + return iter(self._list) + + def __add__(self, other): + return self.union(other) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._list) + + __str__ = __repr__ + + def update(self, iterable): + for e in iterable: + if e not in self: + self._list.append(e) + set.add(self, e) + return self + + def __ior__(self, iterable): + return self.update(iterable) + + def union(self, other): + result = self._copy() + result.update(other) + return result + + def __or__(self, other): + return self.union(other) + + cdef set _to_set(self, other): + cdef set other_set + if isinstance(other, set): + other_set = <set> other + else: + other_set = set(other) + return other_set + + def intersection(self, other): + cdef set other_set = self._to_set(other) + return self._from_list([a for a in self._list if a in other_set]) + + def __and__(self, other): + return self.intersection(other) + + def symmetric_difference(self, other): + cdef set other_set = self._to_set(other) + result = self._from_list([a for a in self._list if a not in other_set]) + # use other here to keep the order + result.update(a for a in other if a not in self) + return result + + def __xor__(self, other): + return self.symmetric_difference(other) + + def difference(self, other): + cdef set other_set = self._to_set(other) + return self._from_list([a for a in self._list if a not in other_set]) + + def __sub__(self, other): + return self.difference(other) + + def intersection_update(self, other): + cdef set other_set = self._to_set(other) + set.intersection_update(self, other_set) + self._list = [a for a in self._list if a in other_set] + return self + + def __iand__(self, other): + return 
self.intersection_update(other) + + def symmetric_difference_update(self, other): + set.symmetric_difference_update(self, other) + self._list = [a for a in self._list if a in self] + self._list += [a for a in other if a in self] + return self + + def __ixor__(self, other): + return self.symmetric_difference_update(other) + + def difference_update(self, other): + set.difference_update(self, other) + self._list = [a for a in self._list if a in self] + return self + + def __isub__(self, other): + return self.difference_update(other) + + +cdef object cy_id(object item): + return PyLong_FromLong(<long> (<void *>item)) + +# NOTE: cython 0.x will call __add__, __sub__, etc with the parameter swapped +# instead of the __rmeth__, so they need to check that also self is of the +# correct type. This is fixed in cython 3.x. See: +# https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods + +cdef class IdentitySet: + """A set that considers only object id() for uniqueness. + + This strategy has edge cases for builtin types- it's possible to have + two 'foo' strings in one of these sets, for example. Use sparingly. 
+ + """ + + cdef dict _members + + def __init__(self, iterable=None): + self._members = {} + if iterable: + self.update(iterable) + + def add(self, value): + self._members[cy_id(value)] = value + + def __contains__(self, value): + return cy_id(value) in self._members + + cpdef remove(self, value): + del self._members[cy_id(value)] + + def discard(self, value): + try: + self.remove(value) + except KeyError: + pass + + def pop(self): + cdef tuple pair + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError("pop from an empty set") + + def clear(self): + self._members.clear() + + def __cmp__(self, other): + raise TypeError("cannot compare sets using cmp()") + + def __eq__(self, other): + cdef IdentitySet other_ + if isinstance(other, IdentitySet): + other_ = other + return self._members == other_._members + else: + return False + + def __ne__(self, other): + cdef IdentitySet other_ + if isinstance(other, IdentitySet): + other_ = other + return self._members != other_._members + else: + return True + + cpdef issubset(self, iterable): + cdef IdentitySet other + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) > len(other): + return False + for m in filterfalse(other._members.__contains__, self._members): + return False + return True + + def __le__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) < len(other) and self.issubset(other) + + cpdef issuperset(self, iterable): + cdef IdentitySet other + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) < len(other): + return False + for m in filterfalse(self._members.__contains__, other._members): + return False + return True + + def __ge__(self, other): + if not isinstance(other, 
IdentitySet): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) + + cpdef IdentitySet union(self, iterable): + cdef IdentitySet result = self.__class__() + result._members.update(self._members) + result.update(iterable) + return result + + def __or__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.union(other) + + cpdef update(self, iterable): + for obj in iterable: + self._members[cy_id(obj)] = obj + + def __ior__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.update(other) + return self + + cpdef difference(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + if isinstance(iterable, self.__class__): + other = (<IdentitySet>iterable)._members + else: + other = {cy_id(obj) for obj in iterable} + result._members = {k:v for k, v in self._members.items() if k not in other} + return result + + def __sub__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.difference(other) + + cpdef difference_update(self, iterable): + cdef IdentitySet other = self.difference(iterable) + self._members = other._members + + def __isub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.difference_update(other) + return self + + cpdef intersection(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + if isinstance(iterable, self.__class__): + other = (<IdentitySet>iterable)._members + else: + other = {cy_id(obj) for obj in iterable} + result._members = {k: v for k, v in self._members.items() if k in other} + return result + + def __and__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + 
return self.intersection(other) + + cpdef intersection_update(self, iterable): + cdef IdentitySet other = self.intersection(iterable) + self._members = other._members + + def __iand__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.intersection_update(other) + return self + + cpdef symmetric_difference(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + cdef dict other + if isinstance(iterable, self.__class__): + other = (<IdentitySet>iterable)._members + else: + other = {cy_id(obj): obj for obj in iterable} + result._members = {k: v for k, v in self._members.items() if k not in other} + result._members.update( + [(k, v) for k, v in other.items() if k not in self._members] + ) + return result + + def __xor__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.symmetric_difference(other) + + cpdef symmetric_difference_update(self, iterable): + cdef IdentitySet other = self.symmetric_difference(iterable) + self._members = other._members + + def __ixor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.symmetric_difference(other) + return self + + cpdef copy(self): + cdef IdentitySet cp = self.__new__(self.__class__) + cp._members = self._members.copy() + return cp + + def __copy__(self): + return self.copy() + + def __len__(self): + return len(self._members) + + def __iter__(self): + return iter(self._members.values()) + + def __hash__(self): + raise TypeError("set objects are unhashable") + + def __repr__(self): + return "%s(%r)" % (type(self).__name__, list(self._members.values())) diff --git a/lib/sqlalchemy/cyextension/immutabledict.pxd b/lib/sqlalchemy/cyextension/immutabledict.pxd new file mode 100644 index 000000000..fe7ad6a81 --- /dev/null +++ b/lib/sqlalchemy/cyextension/immutabledict.pxd @@ -0,0 +1,2 @@ +cdef class immutabledict(dict): + pass diff --git 
a/lib/sqlalchemy/cyextension/immutabledict.pyx b/lib/sqlalchemy/cyextension/immutabledict.pyx new file mode 100644 index 000000000..89bcf3ed6 --- /dev/null +++ b/lib/sqlalchemy/cyextension/immutabledict.pyx @@ -0,0 +1,100 @@ +from cpython.dict cimport PyDict_New, PyDict_Update, PyDict_Size + + +def _immutable_fn(obj): + raise TypeError("%s object is immutable" % obj.__class__.__name__) + + +class ImmutableContainer: + def _immutable(self, *a,**kw): + _immutable_fn(self) + + __delitem__ = __setitem__ = __setattr__ = _immutable + + +cdef class immutabledict(dict): + def __repr__(self): + return f"immutabledict({dict.__repr__(self)})" + + def union(self, *args, **kw): + cdef dict to_merge = None + cdef immutabledict result + cdef Py_ssize_t args_len = len(args) + if args_len > 1: + raise TypeError( + f'union expected at most 1 argument, got {args_len}' + ) + if args_len == 1: + attribute = args[0] + if isinstance(attribute, dict): + to_merge = <dict> attribute + if to_merge is None: + to_merge = dict(*args, **kw) + + if PyDict_Size(to_merge) == 0: + return self + + # new + update is faster than immutabledict(self) + result = immutabledict() + PyDict_Update(result, self) + PyDict_Update(result, to_merge) + return result + + def merge_with(self, *other): + cdef immutabledict result = None + cdef object d + cdef bint update = False + if not other: + return self + for d in other: + if d: + if update == False: + update = True + # new + update is faster than immutabledict(self) + result = immutabledict() + PyDict_Update(result, self) + PyDict_Update( + result, <dict>(d if isinstance(d, dict) else dict(d)) + ) + + return self if update == False else result + + def copy(self): + return self + + def __reduce__(self): + return immutabledict, (dict(self), ) + + def __delitem__(self, k): + _immutable_fn(self) + + def __setitem__(self, k, v): + _immutable_fn(self) + + def __setattr__(self, k, v): + _immutable_fn(self) + + def clear(self, *args, **kw): + _immutable_fn(self) + + def 
pop(self, *args, **kw): + _immutable_fn(self) + + def popitem(self, *args, **kw): + _immutable_fn(self) + + def setdefault(self, *args, **kw): + _immutable_fn(self) + + def update(self, *args, **kw): + _immutable_fn(self) + + # PEP 584 + def __ior__(self, other): + _immutable_fn(self) + + def __or__(self, other): + return immutabledict(super().__or__(other)) + + def __ror__(self, other): + return immutabledict(super().__ror__(other)) diff --git a/lib/sqlalchemy/cyextension/processors.pyx b/lib/sqlalchemy/cyextension/processors.pyx new file mode 100644 index 000000000..9f23e73b1 --- /dev/null +++ b/lib/sqlalchemy/cyextension/processors.pyx @@ -0,0 +1,91 @@ +import datetime +import re + +from cpython.datetime cimport date_new, datetime_new, import_datetime, time_new +from cpython.object cimport PyObject_Str +from cpython.unicode cimport PyUnicode_AsASCIIString, PyUnicode_Check, PyUnicode_Decode +from libc.stdio cimport sscanf + + +def int_to_boolean(value): + if value is None: + return None + return True if value else False + +def to_str(value): + return PyObject_Str(value) if value is not None else None + +def to_float(value): + return float(value) if value is not None else None + +cdef inline bytes to_bytes(object value, str type_name): + try: + return PyUnicode_AsASCIIString(value) + except Exception as e: + raise ValueError( + f"Couldn't parse {type_name} string '{value!r}' " + "- value is not a string." 
+ ) from e + +import_datetime() # required to call datetime_new/date_new/time_new + +def str_to_datetime(value): + if value is None: + return None + cdef int numparsed + cdef unsigned int year, month, day, hour, minute, second, microsecond = 0 + cdef bytes value_b = to_bytes(value, 'datetime') + cdef const char * string = value_b + + numparsed = sscanf(string, "%4u-%2u-%2u %2u:%2u:%2u.%6u", + &year, &month, &day, &hour, &minute, &second, &microsecond) + if numparsed < 6: + raise ValueError( + "Couldn't parse datetime string: '%s'" % (value) + ) + return datetime_new(year, month, day, hour, minute, second, microsecond, None) + +def str_to_date(value): + if value is None: + return None + cdef int numparsed + cdef unsigned int year, month, day + cdef bytes value_b = to_bytes(value, 'date') + cdef const char * string = value_b + + numparsed = sscanf(string, "%4u-%2u-%2u", &year, &month, &day) + if numparsed != 3: + raise ValueError( + "Couldn't parse date string: '%s'" % (value) + ) + return date_new(year, month, day) + +def str_to_time(value): + if value is None: + return None + cdef int numparsed + cdef unsigned int hour, minute, second, microsecond = 0 + cdef bytes value_b = to_bytes(value, 'time') + cdef const char * string = value_b + + numparsed = sscanf(string, "%2u:%2u:%2u.%6u", &hour, &minute, &second, &microsecond) + if numparsed < 3: + raise ValueError( + "Couldn't parse time string: '%s'" % (value) + ) + return time_new(hour, minute, second, microsecond, None) + + +cdef class DecimalResultProcessor: + cdef object type_ + cdef str format_ + + def __cinit__(self, type_, format_): + self.type_ = type_ + self.format_ = format_ + + def process(self, object value): + if value is None: + return None + else: + return self.type_(self.format_ % value) diff --git a/lib/sqlalchemy/cyextension/resultproxy.pyx b/lib/sqlalchemy/cyextension/resultproxy.pyx new file mode 100644 index 000000000..daf5cc940 --- /dev/null +++ b/lib/sqlalchemy/cyextension/resultproxy.pyx @@ -0,0 +1,130 @@ 
+# TODO: this is mostly just copied over from the python implementation +# more improvements are likely possible +import operator + +cdef int MD_INDEX = 0 # integer index in cursor.description + +KEY_INTEGER_ONLY = 0 +KEY_OBJECTS_ONLY = 1 + +sqlalchemy_engine_row = None + +cdef class BaseRow: + cdef readonly object _parent + cdef readonly tuple _data + cdef readonly dict _keymap + cdef readonly int _key_style + + def __init__(self, object parent, object processors, dict keymap, int key_style, object data): + """Row objects are constructed by CursorResult objects.""" + + self._parent = parent + + if processors: + self._data = tuple( + [ + proc(value) if proc else value + for proc, value in zip(processors, data) + ] + ) + else: + self._data = tuple(data) + + self._keymap = keymap + + self._key_style = key_style + + def __reduce__(self): + return ( + rowproxy_reconstructor, + (self.__class__, self.__getstate__()), + ) + + def __getstate__(self): + return { + "_parent": self._parent, + "_data": self._data, + "_key_style": self._key_style, + } + + def __setstate__(self, dict state): + self._parent = state["_parent"] + self._data = state["_data"] + self._keymap = self._parent._keymap + self._key_style = state["_key_style"] + + def _filter_on_values(self, filters): + global sqlalchemy_engine_row + if sqlalchemy_engine_row is None: + from sqlalchemy.engine.row import Row as sqlalchemy_engine_row + + return sqlalchemy_engine_row( + self._parent, + filters, + self._keymap, + self._key_style, + self._data, + ) + + def _values_impl(self): + return list(self) + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def __hash__(self): + return hash(self._data) + + def _get_by_int_impl(self, key): + return self._data[key] + + cpdef _get_by_key_impl(self, key): + # keep two isinstance since it's noticeably faster in the int case + if isinstance(key, int) or isinstance(key, slice): + return self._data[key] + + 
self._parent._raise_for_nonint(key) + + def __getitem__(self, key): + return self._get_by_key_impl(key) + + cpdef _get_by_key_impl_mapping(self, key): + try: + rec = self._keymap[key] + except KeyError as ke: + rec = self._parent._key_fallback(key, ke) + + mdindex = rec[MD_INDEX] + if mdindex is None: + self._parent._raise_for_ambiguous_column_name(rec) + elif ( + self._key_style == KEY_OBJECTS_ONLY + and isinstance(key, int) + ): + raise KeyError(key) + + return self._data[mdindex] + + def __getattr__(self, name): + try: + return self._get_by_key_impl_mapping(name) + except KeyError as e: + raise AttributeError(e.args[0]) from e + + +def rowproxy_reconstructor(cls, state): + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + + +def tuplegetter(*indexes): + it = operator.itemgetter(*indexes) + + if len(indexes) > 1: + return it + else: + return lambda row: (it(row),) diff --git a/lib/sqlalchemy/cyextension/util.pyx b/lib/sqlalchemy/cyextension/util.pyx new file mode 100644 index 000000000..ac15ff9de --- /dev/null +++ b/lib/sqlalchemy/cyextension/util.pyx @@ -0,0 +1,43 @@ +from collections.abc import Mapping + +from sqlalchemy import exc + +cdef tuple _Empty_Tuple = () + +cdef inline bint _mapping_or_tuple(object value): + return isinstance(value, dict) or isinstance(value, tuple) or isinstance(value, Mapping) + +cdef inline bint _check_item(object params) except 0: + cdef object item + cdef bint ret = 1 + if params: + item = params[0] + if not _mapping_or_tuple(item): + ret = 0 + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + return ret + +def _distill_params_20(object params): + if params is None: + return _Empty_Tuple + elif isinstance(params, list) or isinstance(params, tuple): + _check_item(params) + return params + elif isinstance(params, dict) or isinstance(params, Mapping): + return [params] + else: + raise exc.ArgumentError("mapping or list expected for parameters") + + +def 
_distill_raw_params(object params): + if params is None: + return _Empty_Tuple + elif isinstance(params, list): + _check_item(params) + return params + elif _mapping_or_tuple(params): + return [params] + else: + raise exc.ArgumentError("mapping or sequence expected for parameters") |