diff options
author | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-09-23 21:19:23 +0200 |
---|---|---|
committer | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-09-23 21:24:16 +0200 |
commit | 8bf9a18f68a36f81bbd27ce52af65ca3cfd217fd (patch) | |
tree | 871086e88c8ac7670ff1606c3b028bd3845d2c0f | |
parent | 7be6655c34c2e25232964079818fffe0d7d03696 (diff) | |
download | numpy-8bf9a18f68a36f81bbd27ce52af65ca3cfd217fd.tar.gz |
BUG: check if object provides len() before trying to iterate it
some libraries want object arrays from objects that are iterable but
rely on not providing len() to get the right dtype from numpy.
closes gh-5100
-rw-r--r-- | numpy/core/src/multiarray/common.c | 10 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 14 |
2 files changed, 23 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/common.c b/numpy/core/src/multiarray/common.c index 2b3d3c3d2..2f9a58c51 100644 --- a/numpy/core/src/multiarray/common.c +++ b/numpy/core/src/multiarray/common.c @@ -518,12 +518,20 @@ PyArray_DTypeFromObjectHelper(PyObject *obj, int maxdims, return 0; } + /* + * fails if convertable to list but no len is defined which some libraries + * require to get object arrays + */ + size = PySequence_Size(obj); + if (size < 0) { + goto fail; + } + /* Recursive case, first check the sequence contains only one type */ seq = PySequence_Fast(obj, "Could not convert object to sequence"); if (seq == NULL) { goto fail; } - size = PySequence_Fast_GET_SIZE(seq); objects = PySequence_Fast_ITEMS(seq); common_type = size > 0 ? Py_TYPE(objects[0]) : NULL; for (i = 1; i < size; ++i) { diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index d02821cba..523d64848 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -593,6 +593,20 @@ class TestCreation(TestCase): assert_(a.dtype == np.dtype(object)) assert_raises(ValueError, np.array, [Fail()]) + def test_no_len_object_type(self): + # gh-5100, want object array from iterable object without len() + class Point2: + def __init__(self): + pass + + def __getitem__(self, ind): + if ind in [0, 1]: + return ind + else: + raise IndexError() + d = np.array([Point2(), Point2(), Point2()]) + assert_equal(d.dtype, np.dtype(object)) + class TestStructured(TestCase): def test_subarray_field_access(self): |