summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2009-06-19 13:01:20 +0000
committerPauli Virtanen <pav@iki.fi>2009-06-19 13:01:20 +0000
commit6cfb1c6c4ef9b5d44ada70d94f1573c8f964c3ac (patch)
tree7f6b41f4fa86a8809144b2c64f3c7a68f705950b
parentee57730721f7b101cb9477be50a3d1bb255ebf06 (diff)
downloadnumpy-6cfb1c6c4ef9b5d44ada70d94f1573c8f964c3ac.tar.gz
Fixed #1140: avoid div-by-zero in iter_coords_get for size=0 arrays
-rw-r--r--numpy/core/src/multiarray/iterators.c8
-rw-r--r--numpy/lib/tests/test_index_tricks.py8
-rw-r--r--numpy/lib/tests/test_regression.py4
3 files changed, 17 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c
index 842522d73..e447fe07c 100644
--- a/numpy/core/src/multiarray/iterators.c
+++ b/numpy/core/src/multiarray/iterators.c
@@ -1195,8 +1195,12 @@ iter_coords_get(PyArrayIterObject *self)
int i;
val = self->index;
for (i = 0; i < nd; i++) {
- self->coordinates[i] = val / self->factors[i];
- val = val % self->factors[i];
+ if (self->factors[i] != 0) {
+ self->coordinates[i] = val / self->factors[i];
+ val = val % self->factors[i];
+ } else {
+ self->coordinates[i] = 0;
+ }
}
}
return PyArray_IntTupleFromIntp(nd, self->coordinates);
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index 47529502d..641737d43 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -1,5 +1,5 @@
from numpy.testing import *
-from numpy import array, ones, r_, mgrid, unravel_index
+from numpy import array, ones, r_, mgrid, unravel_index, ndenumerate
class TestUnravelIndex(TestCase):
def test_basic(self):
@@ -62,5 +62,11 @@ class TestConcatenator(TestCase):
assert_array_equal(d[5:,:],c)
+class TestNdenumerate(TestCase):
+ def test_basic(self):
+ a = array([[1,2], [3,4]])
+ assert_equal(list(ndenumerate(a)),
+ [((0,0), 1), ((0,1), 2), ((1,0), 3), ((1,1), 4)])
+
if __name__ == "__main__":
run_module_suite()
diff --git a/numpy/lib/tests/test_regression.py b/numpy/lib/tests/test_regression.py
index b8c487962..5abf9aefe 100644
--- a/numpy/lib/tests/test_regression.py
+++ b/numpy/lib/tests/test_regression.py
@@ -48,6 +48,10 @@ class TestRegression(object):
"""Ticket 928."""
assert_raises(ValueError, np.histogramdd, np.ones((1,10)), bins=2**10)
+ def test_ndenumerate_crash(self):
+ """Ticket 1140"""
+ # Shouldn't crash:
+ list(np.ndenumerate(np.array([[]])))
if __name__ == "__main__":
run_module_suite()