diff options
Diffstat (limited to 'lib/sqlalchemy/util/_collections.py')
| -rw-r--r-- | lib/sqlalchemy/util/_collections.py | 122 |
1 files changed, 70 insertions, 52 deletions
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 43440134a..67be0e6bf 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -10,8 +10,13 @@ from __future__ import absolute_import import weakref import operator -from .compat import threading, itertools_filterfalse, string_types, \ - binary_types, collections_abc +from .compat import ( + threading, + itertools_filterfalse, + string_types, + binary_types, + collections_abc, +) from . import py2k import types @@ -77,7 +82,7 @@ class KeyedTuple(AbstractKeyedTuple): t.__dict__.update(zip(labels, vals)) else: labels = [] - t.__dict__['_labels'] = labels + t.__dict__["_labels"] = labels return t @property @@ -139,8 +144,7 @@ class ImmutableContainer(object): class immutabledict(ImmutableContainer, dict): - clear = pop = popitem = setdefault = \ - update = ImmutableContainer._immutable + clear = pop = popitem = setdefault = update = ImmutableContainer._immutable def __new__(cls, *args): new = dict.__new__(cls) @@ -151,7 +155,7 @@ class immutabledict(ImmutableContainer, dict): pass def __reduce__(self): - return immutabledict, (dict(self), ) + return immutabledict, (dict(self),) def union(self, d): if not d: @@ -173,10 +177,10 @@ class immutabledict(ImmutableContainer, dict): class Properties(object): """Provide a __getattr__/__setattr__ interface over a dict.""" - __slots__ = '_data', + __slots__ = ("_data",) def __init__(self, data): - object.__setattr__(self, '_data', data) + object.__setattr__(self, "_data", data) def __len__(self): return len(self._data) @@ -185,7 +189,9 @@ class Properties(object): return iter(list(self._data.values())) def __dir__(self): - return dir(super(Properties, self)) + [str(k) for k in self._data.keys()] + return dir(super(Properties, self)) + [ + str(k) for k in self._data.keys() + ] def __add__(self, other): return list(self) + list(other) @@ -203,10 +209,10 @@ class Properties(object): self._data[key] = obj def __getstate__(self): - return {'_data': self._data} + return {"_data": self._data} def __setstate__(self, state): - object.__setattr__(self, '_data', state['_data']) + object.__setattr__(self, "_data", state["_data"]) def __getattr__(self, key): try: @@ -266,7 +272,7 @@ class ImmutableProperties(ImmutableContainer, Properties): class OrderedDict(dict): """A dict that returns keys/values/items in the order they were added.""" - __slots__ = '_list', + __slots__ = ("_list",) def __reduce__(self): return OrderedDict, (self.items(),) @@ -294,7 +300,7 @@ class OrderedDict(dict): def update(self, ____sequence=None, **kwargs): if ____sequence is not None: - if hasattr(____sequence, 'keys'): + if hasattr(____sequence, "keys"): for key in ____sequence.keys(): self.__setitem__(key, ____sequence[key]) else: @@ -323,6 +329,7 @@ class OrderedDict(dict): return [(key, self[key]) for key in self._list] if py2k: + def itervalues(self): return iter(self.values()) @@ -402,7 +409,7 @@ class OrderedSet(set): return self.union(other) def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self._list) + return "%s(%r)" % (self.__class__.__name__, self._list) __str__ = __repr__ @@ -502,13 +509,13 @@ class IdentitySet(object): pair = self._members.popitem() return pair[1] except KeyError: - raise KeyError('pop from an empty set') + raise KeyError("pop from an empty set") def clear(self): self._members.clear() def __cmp__(self, other): - raise TypeError('cannot compare sets using cmp()') + raise TypeError("cannot compare sets using cmp()") def __eq__(self, other): if isinstance(other, IdentitySet): @@ -527,8 +534,9 @@ class IdentitySet(object): if len(self) > len(other): return False - for m in itertools_filterfalse(other._members.__contains__, - iter(self._members.keys())): + for m in itertools_filterfalse( + other._members.__contains__, iter(self._members.keys()) + ): return False return True @@ -548,8 +556,9 @@ class IdentitySet(object): if len(self) < len(other): return False - for m in itertools_filterfalse(self._members.__contains__, - iter(other._members.keys())): + for m in itertools_filterfalse( + self._members.__contains__, iter(other._members.keys()) + ): return False return True @@ -635,7 +644,8 @@ class IdentitySet(object): members = self._member_id_tuples() other = _iter_id(iterable) result._members.update( - self._working_set(members).symmetric_difference(other)) + self._working_set(members).symmetric_difference(other) + ) return result def _member_id_tuples(self): @@ -667,10 +677,10 @@ class IdentitySet(object): return iter(self._members.values()) def __hash__(self): - raise TypeError('set objects are unhashable') + raise TypeError("set objects are unhashable") def __repr__(self): - return '%s(%r)' % (type(self).__name__, list(self._members.values())) + return "%s(%r)" % (type(self).__name__, list(self._members.values())) class WeakSequence(object): @@ -689,8 +699,9 @@ class WeakSequence(object): return len(self._storage) def __iter__(self): - return (obj for obj in - (ref() for ref in self._storage) if obj is not None) + return ( + obj for obj in (ref() for ref in self._storage) if obj is not None + ) def __getitem__(self, index): try: @@ -732,6 +743,7 @@ class PopulateDict(dict): self[key] = val = self.creator(key) return val + # Define collections that are capable of storing # ColumnElement objects as hashable keys/elements. # At this point, these are mostly historical, things @@ -745,20 +757,21 @@ populate_column_dict = PopulateDict _getters = PopulateDict(operator.itemgetter) _property_getters = PopulateDict( - lambda idx: property(operator.itemgetter(idx))) + lambda idx: property(operator.itemgetter(idx)) +) def unique_list(seq, hashfunc=None): seen = set() seen_add = seen.add if not hashfunc: - return [x for x in seq - if x not in seen - and not seen_add(x)] + return [x for x in seq if x not in seen and not seen_add(x)] else: - return [x for x in seq - if hashfunc(x) not in seen - and not seen_add(hashfunc(x))] + return [ + x + for x in seq + if hashfunc(x) not in seen and not seen_add(hashfunc(x)) + ] class UniqueAppender(object): @@ -773,9 +786,9 @@ class UniqueAppender(object): self._unique = {} if via: self._data_appender = getattr(data, via) - elif hasattr(data, 'append'): + elif hasattr(data, "append"): self._data_appender = data.append - elif hasattr(data, 'add'): + elif hasattr(data, "add"): self._data_appender = data.add def append(self, item): @@ -798,8 +811,9 @@ def coerce_generator_arg(arg): def to_list(x, default=None): if x is None: return default - if not isinstance(x, collections_abc.Iterable) or \ - isinstance(x, string_types + binary_types): + if not isinstance(x, collections_abc.Iterable) or isinstance( + x, string_types + binary_types + ): return [x] elif isinstance(x, list): return x @@ -815,9 +829,7 @@ def has_intersection(set_, iterable): """ # TODO: optimize, write in C, etc. - return bool( - set_.intersection([i for i in iterable if i.__hash__]) - ) + return bool(set_.intersection([i for i in iterable if i.__hash__])) def to_set(x): @@ -854,7 +866,7 @@ def flatten_iterator(x): """ for elem in x: - if not isinstance(elem, str) and hasattr(elem, '__iter__'): + if not isinstance(elem, str) and hasattr(elem, "__iter__"): for y in flatten_iterator(elem): yield y else: @@ -871,9 +883,9 @@ class LRUCache(dict): """ - __slots__ = 'capacity', 'threshold', 'size_alert', '_counter', '_mutex' + __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex" - def __init__(self, capacity=100, threshold=.5, size_alert=None): + def __init__(self, capacity=100, threshold=0.5, size_alert=None): self.capacity = capacity self.threshold = threshold self.size_alert = size_alert @@ -929,10 +941,10 @@ class LRUCache(dict): if size_alert: size_alert = False self.size_alert(self) - by_counter = sorted(dict.values(self), - key=operator.itemgetter(2), - reverse=True) - for item in by_counter[self.capacity:]: + by_counter = sorted( + dict.values(self), key=operator.itemgetter(2), reverse=True + ) + for item in by_counter[self.capacity :]: try: del self[item[0]] except KeyError: @@ -946,17 +958,22 @@ _lw_tuples = LRUCache(100) def lightweight_named_tuple(name, fields): - hash_ = (name, ) + tuple(fields) + hash_ = (name,) + tuple(fields) tp_cls = _lw_tuples.get(hash_) if tp_cls: return tp_cls tp_cls = type( - name, (_LW,), - dict([ - (field, _property_getters[idx]) - for idx, field in enumerate(fields) if field is not None - ] + [('__slots__', ())]) + name, + (_LW,), + dict( + [ + (field, _property_getters[idx]) + for idx, field in enumerate(fields) + if field is not None + ] + + [("__slots__", ())] + ), ) tp_cls._real_fields = fields @@ -1077,6 +1094,7 @@ def has_dupes(sequence, target): return True return False + # .index version. the two __contains__ calls as well # as .index() and isinstance() slow this down. # def has_dupes(sequence, target): |
