summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2014-09-23 21:19:23 +0200
committerJulian Taylor <jtaylor.debian@googlemail.com>2014-09-23 21:24:16 +0200
commit8bf9a18f68a36f81bbd27ce52af65ca3cfd217fd (patch)
tree871086e88c8ac7670ff1606c3b028bd3845d2c0f
parent7be6655c34c2e25232964079818fffe0d7d03696 (diff)
downloadnumpy-8bf9a18f68a36f81bbd27ce52af65ca3cfd217fd.tar.gz
BUG: check if object provides len() before trying to iterate it
some libraries want object arrays from objects that are iterable but rely on not providing len() to get the right dtype from numpy. closes gh-5100
-rw-r--r--numpy/core/src/multiarray/common.c10
-rw-r--r--numpy/core/tests/test_multiarray.py14
2 files changed, 23 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/common.c b/numpy/core/src/multiarray/common.c
index 2b3d3c3d2..2f9a58c51 100644
--- a/numpy/core/src/multiarray/common.c
+++ b/numpy/core/src/multiarray/common.c
@@ -518,12 +518,20 @@ PyArray_DTypeFromObjectHelper(PyObject *obj, int maxdims,
return 0;
}
+ /*
+ * fails if convertable to list but no len is defined which some libraries
+ * require to get object arrays
+ */
+ size = PySequence_Size(obj);
+ if (size < 0) {
+ goto fail;
+ }
+
/* Recursive case, first check the sequence contains only one type */
seq = PySequence_Fast(obj, "Could not convert object to sequence");
if (seq == NULL) {
goto fail;
}
- size = PySequence_Fast_GET_SIZE(seq);
objects = PySequence_Fast_ITEMS(seq);
common_type = size > 0 ? Py_TYPE(objects[0]) : NULL;
for (i = 1; i < size; ++i) {
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index d02821cba..523d64848 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -593,6 +593,20 @@ class TestCreation(TestCase):
assert_(a.dtype == np.dtype(object))
assert_raises(ValueError, np.array, [Fail()])
+ def test_no_len_object_type(self):
+ # gh-5100, want object array from iterable object without len()
+ class Point2:
+ def __init__(self):
+ pass
+
+ def __getitem__(self, ind):
+ if ind in [0, 1]:
+ return ind
+ else:
+ raise IndexError()
+ d = np.array([Point2(), Point2(), Point2()])
+ assert_equal(d.dtype, np.dtype(object))
+
class TestStructured(TestCase):
def test_subarray_field_access(self):