diff options
Diffstat (limited to 'lib/sqlalchemy/cyextension')
| -rw-r--r-- | lib/sqlalchemy/cyextension/collections.pyx | 65 |
1 files changed, 36 insertions, 29 deletions
diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx index e6667dddd..d08fa3aab 100644 --- a/lib/sqlalchemy/cyextension/collections.pyx +++ b/lib/sqlalchemy/cyextension/collections.pyx @@ -1,8 +1,9 @@ cimport cython from cpython.dict cimport PyDict_Merge, PyDict_Update -from cpython.long cimport PyLong_FromLong +from cpython.long cimport PyLong_FromLongLong from cpython.set cimport PySet_Add +from collections.abc import Collection from itertools import filterfalse cdef bint add_not_present(set seen, object item, hashfunc): @@ -39,8 +40,7 @@ cdef class OrderedSet(set): else: self._list = [] - @cython.final - cdef OrderedSet _copy(self): + cpdef OrderedSet copy(self): cdef OrderedSet cp = OrderedSet.__new__(OrderedSet) cp._list = list(self._list) set.update(cp, cp._list) @@ -63,6 +63,14 @@ cdef class OrderedSet(set): set.remove(self, element) self._list.remove(element) + def pop(self): + try: + value = self._list.pop() + except IndexError: + raise KeyError("pop from an empty set") from None + set.remove(self, value) + return value + def insert(self, Py_ssize_t pos, element): if element not in self: self._list.insert(pos, element) @@ -91,34 +99,25 @@ cdef class OrderedSet(set): __str__ = __repr__ - cpdef OrderedSet update(self, iterable): - for e in iterable: - if e not in self: - self._list.append(e) - set.add(self, e) - return self + def update(self, *iterables): + for iterable in iterables: + for e in iterable: + if e not in self: + self._list.append(e) + set.add(self, e) def __ior__(self, iterable): - return self.update(iterable) + self.update(iterable) + return self def union(self, *other): - result = self._copy() - for o in other: - result.update(o) + result = self.copy() + result.update(*other) return result def __or__(self, other): return self.union(other) - @cython.final - cdef set _to_set(self, other): - cdef set other_set - if isinstance(other, set): - other_set = <set> other - else: - other_set = set(other) - return other_set - def intersection(self, *other): cdef set other_set = set.intersection(self, *other) return self._from_list([a for a in self._list if a in other_set]) @@ -127,10 +126,18 @@ cdef class OrderedSet(set): return self.intersection(other) def symmetric_difference(self, other): - cdef set other_set = self._to_set(other) + cdef set other_set + if isinstance(other, set): + other_set = <set> other + collection = other_set + elif isinstance(other, Collection): + collection = other + other_set = set(other) + else: + collection = list(other) + other_set = set(collection) result = self._from_list([a for a in self._list if a not in other_set]) - # use other here to keep the order - result.update(a for a in other if a not in self) + result.update(a for a in collection if a not in self) return result def __xor__(self, other): @@ -152,9 +159,10 @@ cdef class OrderedSet(set): return self cpdef symmetric_difference_update(self, other): - set.symmetric_difference_update(self, other) + collection = other if isinstance(other, Collection) else list(other) + set.symmetric_difference_update(self, collection) self._list = [a for a in self._list if a in self] - self._list += [a for a in other if a in self] + self._list += [a for a in collection if a in self] def __ixor__(self, other): self.symmetric_difference_update(other) @@ -169,13 +177,12 @@ cdef class OrderedSet(set): return self cdef object cy_id(object item): - return PyLong_FromLong(<long> (<void *>item)) + return PyLong_FromLongLong(<long long> (<void *>item)) # NOTE: cython 0.x will call __add__, __sub__, etc with the parameter swapped # instead of the __rmeth__, so they need to check that also self is of the # correct type. This is fixed in cython 3.x. See: # https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods - cdef class IdentitySet: """A set that considers only object id() for uniqueness. |
