summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util/_collections.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/util/_collections.py')
-rw-r--r--lib/sqlalchemy/util/_collections.py122
1 files changed, 70 insertions, 52 deletions
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index 43440134a..67be0e6bf 100644
--- a/lib/sqlalchemy/util/_collections.py
+++ b/lib/sqlalchemy/util/_collections.py
@@ -10,8 +10,13 @@
from __future__ import absolute_import
import weakref
import operator
-from .compat import threading, itertools_filterfalse, string_types, \
- binary_types, collections_abc
+from .compat import (
+ threading,
+ itertools_filterfalse,
+ string_types,
+ binary_types,
+ collections_abc,
+)
from . import py2k
import types
@@ -77,7 +82,7 @@ class KeyedTuple(AbstractKeyedTuple):
t.__dict__.update(zip(labels, vals))
else:
labels = []
- t.__dict__['_labels'] = labels
+ t.__dict__["_labels"] = labels
return t
@property
@@ -139,8 +144,7 @@ class ImmutableContainer(object):
class immutabledict(ImmutableContainer, dict):
- clear = pop = popitem = setdefault = \
- update = ImmutableContainer._immutable
+ clear = pop = popitem = setdefault = update = ImmutableContainer._immutable
def __new__(cls, *args):
new = dict.__new__(cls)
@@ -151,7 +155,7 @@ class immutabledict(ImmutableContainer, dict):
pass
def __reduce__(self):
- return immutabledict, (dict(self), )
+ return immutabledict, (dict(self),)
def union(self, d):
if not d:
@@ -173,10 +177,10 @@ class immutabledict(ImmutableContainer, dict):
class Properties(object):
"""Provide a __getattr__/__setattr__ interface over a dict."""
- __slots__ = '_data',
+ __slots__ = ("_data",)
def __init__(self, data):
- object.__setattr__(self, '_data', data)
+ object.__setattr__(self, "_data", data)
def __len__(self):
return len(self._data)
@@ -185,7 +189,9 @@ class Properties(object):
return iter(list(self._data.values()))
def __dir__(self):
- return dir(super(Properties, self)) + [str(k) for k in self._data.keys()]
+ return dir(super(Properties, self)) + [
+ str(k) for k in self._data.keys()
+ ]
def __add__(self, other):
return list(self) + list(other)
@@ -203,10 +209,10 @@ class Properties(object):
self._data[key] = obj
def __getstate__(self):
- return {'_data': self._data}
+ return {"_data": self._data}
def __setstate__(self, state):
- object.__setattr__(self, '_data', state['_data'])
+ object.__setattr__(self, "_data", state["_data"])
def __getattr__(self, key):
try:
@@ -266,7 +272,7 @@ class ImmutableProperties(ImmutableContainer, Properties):
class OrderedDict(dict):
"""A dict that returns keys/values/items in the order they were added."""
- __slots__ = '_list',
+ __slots__ = ("_list",)
def __reduce__(self):
return OrderedDict, (self.items(),)
@@ -294,7 +300,7 @@ class OrderedDict(dict):
def update(self, ____sequence=None, **kwargs):
if ____sequence is not None:
- if hasattr(____sequence, 'keys'):
+ if hasattr(____sequence, "keys"):
for key in ____sequence.keys():
self.__setitem__(key, ____sequence[key])
else:
@@ -323,6 +329,7 @@ class OrderedDict(dict):
return [(key, self[key]) for key in self._list]
if py2k:
+
def itervalues(self):
return iter(self.values())
@@ -402,7 +409,7 @@ class OrderedSet(set):
return self.union(other)
def __repr__(self):
- return '%s(%r)' % (self.__class__.__name__, self._list)
+ return "%s(%r)" % (self.__class__.__name__, self._list)
__str__ = __repr__
@@ -502,13 +509,13 @@ class IdentitySet(object):
pair = self._members.popitem()
return pair[1]
except KeyError:
- raise KeyError('pop from an empty set')
+ raise KeyError("pop from an empty set")
def clear(self):
self._members.clear()
def __cmp__(self, other):
- raise TypeError('cannot compare sets using cmp()')
+ raise TypeError("cannot compare sets using cmp()")
def __eq__(self, other):
if isinstance(other, IdentitySet):
@@ -527,8 +534,9 @@ class IdentitySet(object):
if len(self) > len(other):
return False
- for m in itertools_filterfalse(other._members.__contains__,
- iter(self._members.keys())):
+ for m in itertools_filterfalse(
+ other._members.__contains__, iter(self._members.keys())
+ ):
return False
return True
@@ -548,8 +556,9 @@ class IdentitySet(object):
if len(self) < len(other):
return False
- for m in itertools_filterfalse(self._members.__contains__,
- iter(other._members.keys())):
+ for m in itertools_filterfalse(
+ self._members.__contains__, iter(other._members.keys())
+ ):
return False
return True
@@ -635,7 +644,8 @@ class IdentitySet(object):
members = self._member_id_tuples()
other = _iter_id(iterable)
result._members.update(
- self._working_set(members).symmetric_difference(other))
+ self._working_set(members).symmetric_difference(other)
+ )
return result
def _member_id_tuples(self):
@@ -667,10 +677,10 @@ class IdentitySet(object):
return iter(self._members.values())
def __hash__(self):
- raise TypeError('set objects are unhashable')
+ raise TypeError("set objects are unhashable")
def __repr__(self):
- return '%s(%r)' % (type(self).__name__, list(self._members.values()))
+ return "%s(%r)" % (type(self).__name__, list(self._members.values()))
class WeakSequence(object):
@@ -689,8 +699,9 @@ class WeakSequence(object):
return len(self._storage)
def __iter__(self):
- return (obj for obj in
- (ref() for ref in self._storage) if obj is not None)
+ return (
+ obj for obj in (ref() for ref in self._storage) if obj is not None
+ )
def __getitem__(self, index):
try:
@@ -732,6 +743,7 @@ class PopulateDict(dict):
self[key] = val = self.creator(key)
return val
+
# Define collections that are capable of storing
# ColumnElement objects as hashable keys/elements.
# At this point, these are mostly historical, things
@@ -745,20 +757,21 @@ populate_column_dict = PopulateDict
_getters = PopulateDict(operator.itemgetter)
_property_getters = PopulateDict(
- lambda idx: property(operator.itemgetter(idx)))
+ lambda idx: property(operator.itemgetter(idx))
+)
def unique_list(seq, hashfunc=None):
seen = set()
seen_add = seen.add
if not hashfunc:
- return [x for x in seq
- if x not in seen
- and not seen_add(x)]
+ return [x for x in seq if x not in seen and not seen_add(x)]
else:
- return [x for x in seq
- if hashfunc(x) not in seen
- and not seen_add(hashfunc(x))]
+ return [
+ x
+ for x in seq
+ if hashfunc(x) not in seen and not seen_add(hashfunc(x))
+ ]
class UniqueAppender(object):
@@ -773,9 +786,9 @@ class UniqueAppender(object):
self._unique = {}
if via:
self._data_appender = getattr(data, via)
- elif hasattr(data, 'append'):
+ elif hasattr(data, "append"):
self._data_appender = data.append
- elif hasattr(data, 'add'):
+ elif hasattr(data, "add"):
self._data_appender = data.add
def append(self, item):
@@ -798,8 +811,9 @@ def coerce_generator_arg(arg):
def to_list(x, default=None):
if x is None:
return default
- if not isinstance(x, collections_abc.Iterable) or \
- isinstance(x, string_types + binary_types):
+ if not isinstance(x, collections_abc.Iterable) or isinstance(
+ x, string_types + binary_types
+ ):
return [x]
elif isinstance(x, list):
return x
@@ -815,9 +829,7 @@ def has_intersection(set_, iterable):
"""
# TODO: optimize, write in C, etc.
- return bool(
- set_.intersection([i for i in iterable if i.__hash__])
- )
+ return bool(set_.intersection([i for i in iterable if i.__hash__]))
def to_set(x):
@@ -854,7 +866,7 @@ def flatten_iterator(x):
"""
for elem in x:
- if not isinstance(elem, str) and hasattr(elem, '__iter__'):
+ if not isinstance(elem, str) and hasattr(elem, "__iter__"):
for y in flatten_iterator(elem):
yield y
else:
@@ -871,9 +883,9 @@ class LRUCache(dict):
"""
- __slots__ = 'capacity', 'threshold', 'size_alert', '_counter', '_mutex'
+ __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex"
- def __init__(self, capacity=100, threshold=.5, size_alert=None):
+ def __init__(self, capacity=100, threshold=0.5, size_alert=None):
self.capacity = capacity
self.threshold = threshold
self.size_alert = size_alert
@@ -929,10 +941,10 @@ class LRUCache(dict):
if size_alert:
size_alert = False
self.size_alert(self)
- by_counter = sorted(dict.values(self),
- key=operator.itemgetter(2),
- reverse=True)
- for item in by_counter[self.capacity:]:
+ by_counter = sorted(
+ dict.values(self), key=operator.itemgetter(2), reverse=True
+ )
+ for item in by_counter[self.capacity :]:
try:
del self[item[0]]
except KeyError:
@@ -946,17 +958,22 @@ _lw_tuples = LRUCache(100)
def lightweight_named_tuple(name, fields):
- hash_ = (name, ) + tuple(fields)
+ hash_ = (name,) + tuple(fields)
tp_cls = _lw_tuples.get(hash_)
if tp_cls:
return tp_cls
tp_cls = type(
- name, (_LW,),
- dict([
- (field, _property_getters[idx])
- for idx, field in enumerate(fields) if field is not None
- ] + [('__slots__', ())])
+ name,
+ (_LW,),
+ dict(
+ [
+ (field, _property_getters[idx])
+ for idx, field in enumerate(fields)
+ if field is not None
+ ]
+ + [("__slots__", ())]
+ ),
)
tp_cls._real_fields = fields
@@ -1077,6 +1094,7 @@ def has_dupes(sequence, target):
return True
return False
+
# .index version. the two __contains__ calls as well
# as .index() and isinstance() slow this down.
# def has_dupes(sequence, target):