diff options
| author | Raymond Hettinger <python@rcn.com> | 2003-12-06 16:23:06 +0000 | 
|---|---|---|
| committer | Raymond Hettinger <python@rcn.com> | 2003-12-06 16:23:06 +0000 | 
| commit | d25c1c635164daa5c300342ac99c0810fd9b575c (patch) | |
| tree | df412ba3ffaa8fee35e2e12f96aab0beecdaaec0 | |
| parent | b8d5f245b7077d869121835ed72656ac14962ef0 (diff) | |
| download | cpython-git-d25c1c635164daa5c300342ac99c0810fd9b575c.tar.gz | |
Implement itertools.groupby()
Original idea by Guido van Rossum.
Idea for skippable inner iterators by Raymond Hettinger.
Idea for argument order and identity function default by Alex Martelli.
Implementation by Hye-Shik Chang (with tweaks by Raymond Hettinger).
| -rw-r--r-- | Doc/lib/libitertools.tex | 60 | ||||
| -rw-r--r-- | Lib/test/test_itertools.py | 108 | ||||
| -rw-r--r-- | Misc/NEWS | 5 | ||||
| -rw-r--r-- | Modules/itertoolsmodule.c | 322 | 
4 files changed, 493 insertions, 2 deletions
| diff --git a/Doc/lib/libitertools.tex b/Doc/lib/libitertools.tex index 6f9f5c6259..82912b0501 100644 --- a/Doc/lib/libitertools.tex +++ b/Doc/lib/libitertools.tex @@ -130,6 +130,54 @@ by functions or loops that truncate the stream.    \end{verbatim}  \end{funcdesc} +\begin{funcdesc}{groupby}{iterable\optional{, key}} +  Make an iterator that returns consecutive keys and groups from the +  \var{iterable}.  \var{key} is function computing a key value for each +  element.  If not specified or is \code{None}, \var{key} defaults to an +  identity function   (returning the element unchanged).  Generally, the +  iterable needs to already be sorted on the same key function. + +  The returned group is itself an iterator that shares the underlying +  iterable with \function{groupby()}.  Because the source is shared, when +  the \function{groupby} object is advanced, the previous group is no +  longer visible.  So, if that data is needed later, it should be stored +  as a list: + +  \begin{verbatim} +    groups = [] +    uniquekeys = [] +    for k, g in groupby(data, keyfunc): +        groups.append(list(g))      # Store group iterator as a list +        uniquekeys.append(k) +  \end{verbatim} + +  \function{groupby()} is equivalent to: + +  \begin{verbatim} +    class groupby(object): +        def __init__(self, iterable, key=None): +            if key is None: +                key = lambda x: x +            self.keyfunc = key +            self.it = iter(iterable) +            self.tgtkey = self.currkey = self.currvalue = xrange(0) +        def __iter__(self): +            return self +        def next(self): +            while self.currkey == self.tgtkey: +                self.currvalue = self.it.next() # Exit on StopIteration +                self.currkey = self.keyfunc(self.currvalue) +            self.tgtkey = self.currkey +            return (self.currkey, self._grouper(self.tgtkey)) +        def _grouper(self, tgtkey): +            while self.currkey == tgtkey: +      
          yield self.currvalue +                self.currvalue = self.it.next() # Exit on StopIteration +                self.currkey = self.keyfunc(self.currvalue) +  \end{verbatim} +  \versionadded{2.4} +\end{funcdesc} +  \begin{funcdesc}{ifilter}{predicate, iterable}    Make an iterator that filters elements from iterable returning only    those for which the predicate is \code{True}. @@ -346,6 +394,18 @@ Martin  Walter  Samuele +# Show a dictionary sorted and grouped by value +>>> from operator import itemgetter +>>> d = dict(a=1, b=2, c=1, d=2, e=1, f=2, g=3) +>>> di = list.sorted(d.iteritems(), key=itemgetter(1)) +>>> for k, g in groupby(di, key=itemgetter(1)): +...     print k, map(itemgetter(0), g) +... +1 ['a', 'c', 'e'] +2 ['b', 'd', 'f'] +3 ['g'] + +  \end{verbatim}  This section shows how itertools can be combined to create other more diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 543acc199c..b4c0a8bff2 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -61,6 +61,94 @@ class TestBasicOps(unittest.TestCase):          self.assertRaises(TypeError, cycle, 5)          self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0]) +    def test_groupby(self): +        # Check whether it accepts arguments correctly +        self.assertEqual([], list(groupby([]))) +        self.assertEqual([], list(groupby([], key=id))) +        self.assertRaises(TypeError, list, groupby('abc', [])) +        self.assertRaises(TypeError, groupby, None) + +        # Check normal input +        s = [(0, 10, 20), (0, 11,21), (0,12,21), (1,13,21), (1,14,22), +             (2,15,22), (3,16,23), (3,17,23)] +        dup = [] +        for k, g in groupby(s, lambda r:r[0]): +            for elem in g: +                self.assertEqual(k, elem[0]) +                dup.append(elem) +        self.assertEqual(s, dup) + +        # Check nested case +        dup = [] +        for k, g in groupby(s, lambda r:r[0]): +            for 
ik, ig in groupby(g, lambda r:r[2]): +                for elem in ig: +                    self.assertEqual(k, elem[0]) +                    self.assertEqual(ik, elem[2]) +                    dup.append(elem) +        self.assertEqual(s, dup) + +        # Check case where inner iterator is not used +        keys = [k for k, g in groupby(s, lambda r:r[0])] +        expectedkeys = set([r[0] for r in s]) +        self.assertEqual(set(keys), expectedkeys) +        self.assertEqual(len(keys), len(expectedkeys)) + +        # Exercise pipes and filters style +        s = 'abracadabra' +        # sort s | uniq +        r = [k for k, g in groupby(list.sorted(s))] +        self.assertEqual(r, ['a', 'b', 'c', 'd', 'r']) +        # sort s | uniq -d +        r = [k for k, g in groupby(list.sorted(s)) if list(islice(g,1,2))] +        self.assertEqual(r, ['a', 'b', 'r']) +        # sort s | uniq -c +        r = [(len(list(g)), k) for k, g in groupby(list.sorted(s))] +        self.assertEqual(r, [(5, 'a'), (2, 'b'), (1, 'c'), (1, 'd'), (2, 'r')]) +        # sort s | uniq -c | sort -rn | head -3 +        r = list.sorted([(len(list(g)) , k) for k, g in groupby(list.sorted(s))], reverse=True)[:3] +        self.assertEqual(r, [(5, 'a'), (2, 'r'), (2, 'b')]) + +        # iter.next failure +        class ExpectedError(Exception): +            pass +        def delayed_raise(n=0): +            for i in range(n): +                yield 'yo' +            raise ExpectedError +        def gulp(iterable, keyp=None, func=list): +            return [func(g) for k, g in groupby(iterable, keyp)] + +        # iter.next failure on outer object +        self.assertRaises(ExpectedError, gulp, delayed_raise(0)) +        # iter.next failure on inner object +        self.assertRaises(ExpectedError, gulp, delayed_raise(1)) + +        # __cmp__ failure +        class DummyCmp: +            def __cmp__(self, dst): +                raise ExpectedError +        s = [DummyCmp(), DummyCmp(), None] + +        # 
__cmp__ failure on outer object +        self.assertRaises(ExpectedError, gulp, s, func=id) +        # __cmp__ failure on inner object +        self.assertRaises(ExpectedError, gulp, s) + +        # keyfunc failure +        def keyfunc(obj): +            if keyfunc.skip > 0: +                keyfunc.skip -= 1 +                return obj +            else: +                raise ExpectedError + +        # keyfunc failure on outer object +        keyfunc.skip = 0 +        self.assertRaises(ExpectedError, gulp, [None], keyfunc) +        keyfunc.skip = 1 +        self.assertRaises(ExpectedError, gulp, [None, None], keyfunc) +      def test_ifilter(self):          self.assertEqual(list(ifilter(isEven, range(6))), [0,2,4])          self.assertEqual(list(ifilter(None, [0,1,0,2,0])), [1,2]) @@ -268,7 +356,7 @@ class TestBasicOps(unittest.TestCase):      def test_StopIteration(self):          self.assertRaises(StopIteration, izip().next) -        for f in (chain, cycle, izip): +        for f in (chain, cycle, izip, groupby):              self.assertRaises(StopIteration, f([]).next)              self.assertRaises(StopIteration, f(StopNow()).next) @@ -426,6 +514,14 @@ class TestVariousIteratorArgs(unittest.TestCase):              self.assertRaises(TypeError, list, cycle(N(s)))              self.assertRaises(ZeroDivisionError, list, cycle(E(s))) +    def test_groupby(self): +        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)): +            for g in (G, I, Ig, S, L, R): +                self.assertEqual([k for k, sb in groupby(g(s))], list(g(s))) +            self.assertRaises(TypeError, groupby, X(s)) +            self.assertRaises(TypeError, list, groupby(N(s))) +            self.assertRaises(ZeroDivisionError, list, groupby(E(s))) +      def test_ifilter(self):          for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):              for g in (G, I, Ig, S, L, R): @@ -571,6 +667,16 @@ Martin  Walter  Samuele +>>> from 
operator import itemgetter +>>> d = dict(a=1, b=2, c=1, d=2, e=1, f=2, g=3) +>>> di = list.sorted(d.iteritems(), key=itemgetter(1)) +>>> for k, g in groupby(di, itemgetter(1)): +...     print k, map(itemgetter(0), g) +... +1 ['a', 'c', 'e'] +2 ['b', 'd', 'f'] +3 ['g'] +  >>> def take(n, seq):  ...     return list(islice(seq, n)) @@ -164,6 +164,11 @@ Extension modules    SF bug #812202).  Generators that do not define genrandbits() now    issue a warning when randrange() is called with a range that large. +- itertools has a new function, groupby() for aggregating iterables +  into groups sharing the same key (as determined by a key function). +  It offers some of functionality of SQL's groupby keyword and of +  the Unix uniq filter. +                                                                    - itertools now has a new function, tee() which produces two independent    iterators from a single iterable. diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index a341a6630a..387133c474 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -7,6 +7,323 @@     All rights reserved.  
*/ + +/* groupby object ***********************************************************/ + +typedef struct { +	PyObject_HEAD +	PyObject *it; +	PyObject *keyfunc; +	PyObject *tgtkey; +	PyObject *currkey; +	PyObject *currvalue; +} groupbyobject; + +static PyTypeObject groupby_type; +static PyObject *_grouper_create(groupbyobject *, PyObject *); + +static PyObject * +groupby_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ +	static char *kwargs[] = {"iterable", "key", NULL}; +	groupbyobject *gbo; + 	PyObject *it, *keyfunc = Py_None; +  + 	if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:groupby", kwargs, +					 &it, &keyfunc)) +		return NULL; + +	gbo = (groupbyobject *)type->tp_alloc(type, 0); +	if (gbo == NULL) +		return NULL; +	gbo->tgtkey = NULL; +	gbo->currkey = NULL; +	gbo->currvalue = NULL; +	gbo->keyfunc = keyfunc; +	Py_INCREF(keyfunc); +	gbo->it = PyObject_GetIter(it); +	if (gbo->it == NULL) { +		Py_DECREF(gbo); +		return NULL; +	} +	return (PyObject *)gbo; +} + +static void +groupby_dealloc(groupbyobject *gbo) +{ +	PyObject_GC_UnTrack(gbo); +	Py_XDECREF(gbo->it); +	Py_XDECREF(gbo->keyfunc); +	Py_XDECREF(gbo->tgtkey); +	Py_XDECREF(gbo->currkey); +	Py_XDECREF(gbo->currvalue); +	gbo->ob_type->tp_free(gbo); +} + +static int +groupby_traverse(groupbyobject *gbo, visitproc visit, void *arg) +{ +	int err; + +	if (gbo->it) { +		err = visit(gbo->it, arg); +		if (err) +			return err; +	} +	if (gbo->keyfunc) { +		err = visit(gbo->keyfunc, arg); +		if (err) +			return err; +	} +	if (gbo->tgtkey) { +		err = visit(gbo->tgtkey, arg); +		if (err) +			return err; +	} +	if (gbo->currkey) { +		err = visit(gbo->currkey, arg); +		if (err) +			return err; +	} +	if (gbo->currvalue) { +		err = visit(gbo->currvalue, arg); +		if (err) +			return err; +	} +	return 0; +} + +static PyObject * +groupby_next(groupbyobject *gbo) +{ +	PyObject *newvalue, *newkey, *r, *grouper; + +	/* skip to next iteration group */ +	for (;;) { +		if (gbo->currkey == NULL) +			/* pass */; +		else if 
(gbo->tgtkey == NULL) +			break; +		else { +			int rcmp; + +			rcmp = PyObject_RichCompareBool(gbo->tgtkey, +							gbo->currkey, Py_EQ); +			if (rcmp == -1) +				return NULL; +			else if (rcmp == 0) +				break; +		} + +		newvalue = PyIter_Next(gbo->it); +		if (newvalue == NULL) +			return NULL; + +		if (gbo->keyfunc == Py_None) { +			newkey = newvalue; +			Py_INCREF(newvalue); +		} else { +			newkey = PyObject_CallFunctionObjArgs(gbo->keyfunc, +							      newvalue, NULL); +			if (newkey == NULL) { +				Py_DECREF(newvalue); +				return NULL; +			} +		} + +		Py_XDECREF(gbo->currkey); +		gbo->currkey = newkey; +		Py_XDECREF(gbo->currvalue); +		gbo->currvalue = newvalue; +	} + +	Py_XDECREF(gbo->tgtkey); +	gbo->tgtkey = gbo->currkey; +	Py_INCREF(gbo->currkey); + +	grouper = _grouper_create(gbo, gbo->tgtkey); +	if (grouper == NULL) +		return NULL; + +	r = PyTuple_Pack(2, gbo->currkey, grouper); +	Py_DECREF(grouper); +	return r; +} + +PyDoc_STRVAR(groupby_doc, +"groupby(iterable[, keyfunc]) -> create an iterator which returns\n\ +(key, sub-iterator) grouped by each value of key(value).\n"); + +static PyTypeObject groupby_type = { +	PyObject_HEAD_INIT(NULL) +	0,				/* ob_size */ +	"itertools.groupby",		/* tp_name */ +	sizeof(groupbyobject),		/* tp_basicsize */ +	0,				/* tp_itemsize */ +	/* methods */ +	(destructor)groupby_dealloc,	/* tp_dealloc */ +	0,				/* tp_print */ +	0,				/* tp_getattr */ +	0,				/* tp_setattr */ +	0,				/* tp_compare */ +	0,				/* tp_repr */ +	0,				/* tp_as_number */ +	0,				/* tp_as_sequence */ +	0,				/* tp_as_mapping */ +	0,				/* tp_hash */ +	0,				/* tp_call */ +	0,				/* tp_str */ +	PyObject_GenericGetAttr,	/* tp_getattro */ +	0,				/* tp_setattro */ +	0,				/* tp_as_buffer */ +	Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | +		Py_TPFLAGS_BASETYPE,	/* tp_flags */ +	groupby_doc,			/* tp_doc */ +	(traverseproc)groupby_traverse,	/* tp_traverse */ +	0,				/* tp_clear */ +	0,				/* tp_richcompare */ +	0,				/* tp_weaklistoffset */ +	
PyObject_SelfIter,		/* tp_iter */ +	(iternextfunc)groupby_next,	/* tp_iternext */ +	0,				/* tp_methods */ +	0,				/* tp_members */ +	0,				/* tp_getset */ +	0,				/* tp_base */ +	0,				/* tp_dict */ +	0,				/* tp_descr_get */ +	0,				/* tp_descr_set */ +	0,				/* tp_dictoffset */ +	0,				/* tp_init */ +	0,				/* tp_alloc */ +	groupby_new,			/* tp_new */ +	PyObject_GC_Del,		/* tp_free */ +}; + + +/* _grouper object (internal) ************************************************/ + +typedef struct { +	PyObject_HEAD +	PyObject *parent; +	PyObject *tgtkey; +} _grouperobject; + +static PyTypeObject _grouper_type; + +static PyObject * +_grouper_create(groupbyobject *parent, PyObject *tgtkey) +{ +	_grouperobject *igo; + +	igo = PyObject_New(_grouperobject, &_grouper_type); +	if (igo == NULL) +		return NULL; +	igo->parent = (PyObject *)parent; +	Py_INCREF(parent); +	igo->tgtkey = tgtkey; +	Py_INCREF(tgtkey); + +	return (PyObject *)igo; +} + +static void +_grouper_dealloc(_grouperobject *igo) +{ +	Py_DECREF(igo->parent); +	Py_DECREF(igo->tgtkey); +	PyObject_Del(igo); +} + +static PyObject * +_grouper_next(_grouperobject *igo) +{ +	groupbyobject *gbo = (groupbyobject *)igo->parent; +	PyObject *newvalue, *newkey, *r; +	int rcmp; + +	if (gbo->currvalue == NULL) { +		newvalue = PyIter_Next(gbo->it); +		if (newvalue == NULL) +			return NULL; + +		if (gbo->keyfunc == Py_None) { +			newkey = newvalue; +			Py_INCREF(newvalue); +		} else { +			newkey = PyObject_CallFunctionObjArgs(gbo->keyfunc, +							      newvalue, NULL); +			if (newkey == NULL) { +				Py_DECREF(newvalue); +				return NULL; +			} +		} + +		assert(gbo->currkey == NULL); +		gbo->currkey = newkey; +		gbo->currvalue = newvalue; +	} + +	assert(gbo->currkey != NULL); +	rcmp = PyObject_RichCompareBool(igo->tgtkey, gbo->currkey, Py_EQ); +	if (rcmp <= 0) +		/* got any error or current group is end */ +		return NULL; + +	r = gbo->currvalue; +	gbo->currvalue = NULL; +	Py_DECREF(gbo->currkey); +	gbo->currkey = NULL; + +	return 
r; +} + +static PyTypeObject _grouper_type = { +	PyObject_HEAD_INIT(NULL) +	0,				/* ob_size */ +	"itertools._grouper",		/* tp_name */ +	sizeof(_grouperobject),		/* tp_basicsize */ +	0,				/* tp_itemsize */ +	/* methods */ +	(destructor)_grouper_dealloc,	/* tp_dealloc */ +	0,				/* tp_print */ +	0,				/* tp_getattr */ +	0,				/* tp_setattr */ +	0,				/* tp_compare */ +	0,				/* tp_repr */ +	0,				/* tp_as_number */ +	0,				/* tp_as_sequence */ +	0,				/* tp_as_mapping */ +	0,				/* tp_hash */ +	0,				/* tp_call */ +	0,				/* tp_str */ +	PyObject_GenericGetAttr,	/* tp_getattro */ +	0,				/* tp_setattro */ +	0,				/* tp_as_buffer */ +	Py_TPFLAGS_DEFAULT,		/* tp_flags */ +	0,				/* tp_doc */ +	0, 				/* tp_traverse */ +	0,				/* tp_clear */ +	0,				/* tp_richcompare */ +	0,				/* tp_weaklistoffset */ +	PyObject_SelfIter,		/* tp_iter */ +	(iternextfunc)_grouper_next,	/* tp_iternext */ +	0,				/* tp_methods */ +	0,				/* tp_members */ +	0,				/* tp_getset */ +	0,				/* tp_base */ +	0,				/* tp_dict */ +	0,				/* tp_descr_get */ +	0,				/* tp_descr_set */ +	0,				/* tp_dictoffset */ +	0,				/* tp_init */ +	0,				/* tp_alloc */ +	0,				/* tp_new */ +	PyObject_Del,			/* tp_free */ +}; + +  +  /* tee object and with supporting function and objects ***************/  /* The teedataobject pre-allocates space for LINKCELLS number of objects. @@ -2103,6 +2420,7 @@ tee(it, n=2) --> (it1, it2 , ... itn) splits one iterator into n\n\  chain(p, q, ...) --> p0, p1, ... plast, q0, q1, ... \n\  takewhile(pred, seq) --> seq[0], seq[1], until pred fails\n\  dropwhile(pred, seq) --> seq[n], seq[n+1], starting when pred fails\n\ +groupby(iterable[, keyfunc]) --> sub-iterators grouped by value of keyfunc(v)\n\  "); @@ -2130,6 +2448,7 @@ inititertools(void)  		&count_type,  		&izip_type,  		&repeat_type, +		&groupby_type,  		NULL  	}; @@ -2148,5 +2467,6 @@ inititertools(void)  		return;  	if (PyType_Ready(&tee_type) < 0)  		return; - +	if (PyType_Ready(&_grouper_type) < 0) +		return;  } | 
