diff options
| author | Federico Caselli <cfederico87@gmail.com> | 2023-03-14 23:17:07 +0100 |
|---|---|---|
| committer | Federico Caselli <cfederico87@gmail.com> | 2023-03-30 22:18:11 +0200 |
| commit | a979b6dc5ebefedfd8c85f5695cc5be8882eaa29 (patch) | |
| tree | 8af2f9102fa109b0fa968cada17004e3d2b41e5f /lib | |
| parent | 77357be824095b46eb2ed3206bc555a6dacc7f30 (diff) | |
| download | sqlalchemy-a979b6dc5ebefedfd8c85f5695cc5be8882eaa29.tar.gz | |
Add missing methods to OrderedSet.
Implemented missing method ``copy`` and ``pop`` in OrderedSet class.
Fixes: #9487
Change-Id: I1d2278b64939b44422e9d5857ec7d345fff53997
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/cyextension/collections.pyx | 65 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/_py_collections.py | 37 |
2 files changed, 64 insertions, 38 deletions
diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx index e6667dddd..d08fa3aab 100644 --- a/lib/sqlalchemy/cyextension/collections.pyx +++ b/lib/sqlalchemy/cyextension/collections.pyx @@ -1,8 +1,9 @@ cimport cython from cpython.dict cimport PyDict_Merge, PyDict_Update -from cpython.long cimport PyLong_FromLong +from cpython.long cimport PyLong_FromLongLong from cpython.set cimport PySet_Add +from collections.abc import Collection from itertools import filterfalse cdef bint add_not_present(set seen, object item, hashfunc): @@ -39,8 +40,7 @@ cdef class OrderedSet(set): else: self._list = [] - @cython.final - cdef OrderedSet _copy(self): + cpdef OrderedSet copy(self): cdef OrderedSet cp = OrderedSet.__new__(OrderedSet) cp._list = list(self._list) set.update(cp, cp._list) @@ -63,6 +63,14 @@ cdef class OrderedSet(set): set.remove(self, element) self._list.remove(element) + def pop(self): + try: + value = self._list.pop() + except IndexError: + raise KeyError("pop from an empty set") from None + set.remove(self, value) + return value + def insert(self, Py_ssize_t pos, element): if element not in self: self._list.insert(pos, element) @@ -91,34 +99,25 @@ cdef class OrderedSet(set): __str__ = __repr__ - cpdef OrderedSet update(self, iterable): - for e in iterable: - if e not in self: - self._list.append(e) - set.add(self, e) - return self + def update(self, *iterables): + for iterable in iterables: + for e in iterable: + if e not in self: + self._list.append(e) + set.add(self, e) def __ior__(self, iterable): - return self.update(iterable) + self.update(iterable) + return self def union(self, *other): - result = self._copy() - for o in other: - result.update(o) + result = self.copy() + result.update(*other) return result def __or__(self, other): return self.union(other) - @cython.final - cdef set _to_set(self, other): - cdef set other_set - if isinstance(other, set): - other_set = <set> other - else: - other_set = set(other) - return other_set - def intersection(self, *other): cdef set other_set = set.intersection(self, *other) return self._from_list([a for a in self._list if a in other_set]) @@ -127,10 +126,18 @@ cdef class OrderedSet(set): return self.intersection(other) def symmetric_difference(self, other): - cdef set other_set = self._to_set(other) + cdef set other_set + if isinstance(other, set): + other_set = <set> other + collection = other_set + elif isinstance(other, Collection): + collection = other + other_set = set(other) + else: + collection = list(other) + other_set = set(collection) result = self._from_list([a for a in self._list if a not in other_set]) - # use other here to keep the order - result.update(a for a in other if a not in self) + result.update(a for a in collection if a not in self) return result def __xor__(self, other): @@ -152,9 +159,10 @@ cdef class OrderedSet(set): return self cpdef symmetric_difference_update(self, other): - set.symmetric_difference_update(self, other) + collection = other if isinstance(other, Collection) else list(other) + set.symmetric_difference_update(self, collection) self._list = [a for a in self._list if a in self] - self._list += [a for a in other if a in self] + self._list += [a for a in collection if a in self] def __ixor__(self, other): self.symmetric_difference_update(other) @@ -169,13 +177,12 @@ cdef class OrderedSet(set): return self cdef object cy_id(object item): - return PyLong_FromLong(<long> (<void *>item)) + return PyLong_FromLongLong(<long long> (<void *>item)) # NOTE: cython 0.x will call __add__, __sub__, etc with the parameter swapped # instead of the __rmeth__, so they need to check that also self is of the # correct type. This is fixed in cython 3.x. See: # https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods - cdef class IdentitySet: """A set that considers only object id() for uniqueness. diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index 8810800c4..9962493b5 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -168,8 +168,11 @@ class OrderedSet(Set[_T]): else: self._list = [] - def __reduce__(self): - return (OrderedSet, (self._list,)) + def copy(self) -> OrderedSet[_T]: + cp = self.__class__() + cp._list = self._list.copy() + set.update(cp, cp._list) + return cp def add(self, element: _T) -> None: if element not in self: @@ -180,6 +183,14 @@ class OrderedSet(Set[_T]): super().remove(element) self._list.remove(element) + def pop(self) -> _T: + try: + value = self._list.pop() + except IndexError: + raise KeyError("pop from an empty set") from None + super().remove(value) + return value + def insert(self, pos: int, element: _T) -> None: if element not in self: self._list.insert(pos, element) @@ -220,9 +231,8 @@ class OrderedSet(Set[_T]): return self # type: ignore def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]: - result: OrderedSet[Union[_T, _S]] = self.__class__(self) # type: ignore # noqa: E501 - for o in other: - result.update(o) + result: OrderedSet[Union[_T, _S]] = self.copy() # type: ignore + result.update(*other) return result def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: @@ -237,9 +247,17 @@ class OrderedSet(Set[_T]): return self.intersection(other) def symmetric_difference(self, other: Iterable[_T]) -> OrderedSet[_T]: - other_set = other if isinstance(other, set) else set(other) + collection: Collection[_T] + if isinstance(other, set): + collection = other_set = other + elif isinstance(other, Collection): + collection = other + other_set = set(other) + else: + collection = list(other) + other_set = set(collection) result = self.__class__(a for a in self if a not in other_set) - result.update(a for a in other if a not in self) + result.update(a for a in collection if a not in self) return result def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: @@ -263,9 +281,10 @@ class OrderedSet(Set[_T]): return self def symmetric_difference_update(self, other: Iterable[Any]) -> None: - super().symmetric_difference_update(other) + collection = other if isinstance(other, Collection) else list(other) + super().symmetric_difference_update(collection) self._list = [a for a in self._list if a in self] - self._list += [a for a in other if a in self] + self._list += [a for a in collection if a in self] def __ixor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]: self.symmetric_difference_update(other) |
