diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-06-08 11:44:17 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-08 11:44:17 -0400 |
commit | ccb2a712ab55a34b26b06bfd325020af3889b00c (patch) | |
tree | 9eb59ad9d313c9c1490e19682280d1820028902a /numpy/core | |
parent | da5eaf97281453083252c22b2b94aded846a936b (diff) | |
parent | 95c1e49c64ebb7a60dae7355a9f37d1cefef9cc6 (diff) | |
download | numpy-ccb2a712ab55a34b26b06bfd325020af3889b00c.tar.gz |
Merge pull request #11155 from eric-wieser/datetime-stack-overflow
BUG: Prevent stackoverflow in conversion to datetime types
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/include/numpy/npy_3kcompat.h | 8 | ||||
-rw-r--r-- | numpy/core/src/multiarray/datetime.c | 24 | ||||
-rw-r--r-- | numpy/core/src/multiarray/number.c | 13 | ||||
-rw-r--r-- | numpy/core/tests/test_datetime.py | 17 |
4 files changed, 41 insertions, 21 deletions
diff --git a/numpy/core/include/numpy/npy_3kcompat.h b/numpy/core/include/numpy/npy_3kcompat.h index 56fbd99af..2d0ccd3b9 100644 --- a/numpy/core/include/numpy/npy_3kcompat.h +++ b/numpy/core/include/numpy/npy_3kcompat.h @@ -61,6 +61,14 @@ static NPY_INLINE int PyInt_Check(PyObject *op) { PySlice_GetIndicesEx((PySliceObject *)op, nop, start, end, step, slicelength) #endif +/* <2.7.11 and <3.4.4 have the wrong argument type for Py_EnterRecursiveCall */ +#if (PY_VERSION_HEX < 0x02070B00) || \ + ((0x03000000 <= PY_VERSION_HEX) && (PY_VERSION_HEX < 0x03040400)) + #define Npy_EnterRecursiveCall(x) Py_EnterRecursiveCall((char *)(x)) +#else + #define Npy_EnterRecursiveCall(x) Py_EnterRecursiveCall(x) +#endif + /* * PyString -> PyBytes */ diff --git a/numpy/core/src/multiarray/datetime.c b/numpy/core/src/multiarray/datetime.c index af542aecc..df9b9cec4 100644 --- a/numpy/core/src/multiarray/datetime.c +++ b/numpy/core/src/multiarray/datetime.c @@ -3721,19 +3721,21 @@ recursive_find_object_datetime64_type(PyObject *obj, } for (i = 0; i < len; ++i) { + int ret; PyObject *f = PySequence_GetItem(obj, i); if (f == NULL) { return -1; } - if (f == obj) { - Py_DECREF(f); - return 0; - } - if (recursive_find_object_datetime64_type(f, meta) < 0) { + if (Npy_EnterRecursiveCall(" in recursive_find_object_datetime64_type") != 0) { Py_DECREF(f); return -1; } + ret = recursive_find_object_datetime64_type(f, meta); + Py_LeaveRecursiveCall(); Py_DECREF(f); + if (ret < 0) { + return ret; + } } return 0; @@ -3823,19 +3825,21 @@ recursive_find_object_timedelta64_type(PyObject *obj, } for (i = 0; i < len; ++i) { + int ret; PyObject *f = PySequence_GetItem(obj, i); if (f == NULL) { return -1; } - if (f == obj) { - Py_DECREF(f); - return 0; - } - if (recursive_find_object_timedelta64_type(f, meta) < 0) { + if (Npy_EnterRecursiveCall(" in recursive_find_object_timedelta64_type") != 0) { Py_DECREF(f); return -1; } + ret = recursive_find_object_timedelta64_type(f, meta); + Py_LeaveRecursiveCall(); Py_DECREF(f); + if (ret < 0) { + return ret; + } } return 0; diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 14389a925..448d2d9c2 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -16,15 +16,6 @@ #include "binop_override.h" -/* <2.7.11 and <3.4.4 have the wrong argument type for Py_EnterRecursiveCall */ -#if (PY_VERSION_HEX < 0x02070B00) || \ - ((0x03000000 <= PY_VERSION_HEX) && (PY_VERSION_HEX < 0x03040400)) - #define _Py_EnterRecursiveCall(x) Py_EnterRecursiveCall((char *)(x)) -#else - #define _Py_EnterRecursiveCall(x) Py_EnterRecursiveCall(x) -#endif - - /************************************************************************* **************** Implement Number Protocol **************************** *************************************************************************/ @@ -796,7 +787,7 @@ _array_nonzero(PyArrayObject *mp) n = PyArray_SIZE(mp); if (n == 1) { int res; - if (_Py_EnterRecursiveCall(" while converting array to bool")) { + if (Npy_EnterRecursiveCall(" while converting array to bool")) { return -1; } res = PyArray_DESCR(mp)->f->nonzero(PyArray_DATA(mp), mp); @@ -850,7 +841,7 @@ array_scalar_forward(PyArrayObject *v, /* Need to guard against recursion if our array holds references */ if (PyDataType_REFCHK(PyArray_DESCR(v))) { PyObject *res; - if (_Py_EnterRecursiveCall(where) != 0) { + if (Npy_EnterRecursiveCall(where) != 0) { Py_DECREF(scalar); return NULL; } diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py index e433877e8..942554cae 100644 --- a/numpy/core/tests/test_datetime.py +++ b/numpy/core/tests/test_datetime.py @@ -17,6 +17,11 @@ try: except ImportError: _has_pytz = False +try: + RecursionError +except NameError: + RecursionError = RuntimeError # python < 3.5 + class TestDateTime(object): def test_datetime_dtype_creation(self): @@ -1997,6 +2002,18 @@ class TestDateTime(object): continue assert_raises(TypeError, np.isnat, np.zeros(10, t)) + def test_corecursive_input(self): + # construct a co-recursive list + a, b = [], [] + a.append(b) + b.append(a) + obj_arr = np.array([None]) + obj_arr[0] = a + + # gh-11154: This shouldn't cause a C stack overflow + assert_raises(RecursionError, obj_arr.astype, 'M8') + assert_raises(RecursionError, obj_arr.astype, 'm8') + class TestDateTimeData(object): |