summaryrefslogtreecommitdiff
path: root/Lib/heapq.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/heapq.py')
-rw-r--r--Lib/heapq.py75
1 files changed, 62 insertions, 13 deletions
diff --git a/Lib/heapq.py b/Lib/heapq.py
index 48f804a773..74f7310a2c 100644
--- a/Lib/heapq.py
+++ b/Lib/heapq.py
@@ -129,10 +129,15 @@ From all times, sorting has always been a Great Art! :-)
__all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'merge',
'nlargest', 'nsmallest', 'heappushpop']
-from itertools import islice, repeat, count, imap, izip, tee
-from operator import itemgetter, neg
+from itertools import islice, repeat, count, imap, izip, tee, chain
+from operator import itemgetter
import bisect
+def cmp_lt(x, y):
+ # Use __lt__ if available; otherwise, try __le__.
+ # In Py3.x, only __lt__ will be called.
+ return (x < y) if hasattr(x, '__lt__') else (not y <= x)
+
def heappush(heap, item):
"""Push item onto heap, maintaining the heap invariant."""
heap.append(item)
@@ -167,7 +172,7 @@ def heapreplace(heap, item):
def heappushpop(heap, item):
"""Fast version of a heappush followed by a heappop."""
- if heap and heap[0] < item:
+ if heap and cmp_lt(heap[0], item):
item, heap[0] = heap[0], item
_siftup(heap, 0)
return item
@@ -215,11 +220,10 @@ def nsmallest(n, iterable):
pop = result.pop
los = result[-1] # los --> Largest of the nsmallest
for elem in it:
- if los <= elem:
- continue
- insort(result, elem)
- pop()
- los = result[-1]
+ if cmp_lt(elem, los):
+ insort(result, elem)
+ pop()
+ los = result[-1]
return result
# An alternative approach manifests the whole iterable in memory but
# saves comparisons by heapifying all at once. Also, saves time
@@ -240,7 +244,7 @@ def _siftdown(heap, startpos, pos):
while pos > startpos:
parentpos = (pos - 1) >> 1
parent = heap[parentpos]
- if newitem < parent:
+ if cmp_lt(newitem, parent):
heap[pos] = parent
pos = parentpos
continue
@@ -295,7 +299,7 @@ def _siftup(heap, pos):
while childpos < endpos:
# Set childpos to index of smaller child.
rightpos = childpos + 1
- if rightpos < endpos and not heap[childpos] < heap[rightpos]:
+ if rightpos < endpos and not cmp_lt(heap[childpos], heap[rightpos]):
childpos = rightpos
# Move the smaller child up.
heap[pos] = heap[childpos]
@@ -308,7 +312,7 @@ def _siftup(heap, pos):
# If available, use C implementation
try:
- from _heapq import heappush, heappop, heapify, heapreplace, nlargest, nsmallest, heappushpop
+ from _heapq import *
except ImportError:
pass
@@ -354,10 +358,32 @@ def nsmallest(n, iterable, key=None):
Equivalent to: sorted(iterable, key=key)[:n]
"""
+ # Short-cut for n==1 is to use min() when len(iterable)>0
+ if n == 1:
+ it = iter(iterable)
+ head = list(islice(it, 1))
+ if not head:
+ return []
+ if key is None:
+ return [min(chain(head, it))]
+ return [min(chain(head, it), key=key)]
+
+ # When n>=size, it's faster to use sort()
+ try:
+ size = len(iterable)
+ except (TypeError, AttributeError):
+ pass
+ else:
+ if n >= size:
+ return sorted(iterable, key=key)[:n]
+
+ # When key is none, use simpler decoration
if key is None:
it = izip(iterable, count()) # decorate
result = _nsmallest(n, it)
return map(itemgetter(0), result) # undecorate
+
+ # General case, slowest method
in1, in2 = tee(iterable)
it = izip(imap(key, in1), count(), in2) # decorate
result = _nsmallest(n, it)
@@ -369,12 +395,35 @@ def nlargest(n, iterable, key=None):
Equivalent to: sorted(iterable, key=key, reverse=True)[:n]
"""
+
+ # Short-cut for n==1 is to use max() when len(iterable)>0
+ if n == 1:
+ it = iter(iterable)
+ head = list(islice(it, 1))
+ if not head:
+ return []
+ if key is None:
+ return [max(chain(head, it))]
+ return [max(chain(head, it), key=key)]
+
+ # When n>=size, it's faster to use sort()
+ try:
+ size = len(iterable)
+ except (TypeError, AttributeError):
+ pass
+ else:
+ if n >= size:
+ return sorted(iterable, key=key, reverse=True)[:n]
+
+ # When key is none, use simpler decoration
if key is None:
- it = izip(iterable, imap(neg, count())) # decorate
+ it = izip(iterable, count(0,-1)) # decorate
result = _nlargest(n, it)
return map(itemgetter(0), result) # undecorate
+
+ # General case, slowest method
in1, in2 = tee(iterable)
- it = izip(imap(key, in1), imap(neg, count()), in2) # decorate
+ it = izip(imap(key, in1), count(0,-1), in2) # decorate
result = _nlargest(n, it)
return map(itemgetter(2), result) # undecorate