summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-06-08 11:44:17 -0400
committerGitHub <noreply@github.com>2018-06-08 11:44:17 -0400
commitccb2a712ab55a34b26b06bfd325020af3889b00c (patch)
tree9eb59ad9d313c9c1490e19682280d1820028902a /numpy/core
parentda5eaf97281453083252c22b2b94aded846a936b (diff)
parent95c1e49c64ebb7a60dae7355a9f37d1cefef9cc6 (diff)
downloadnumpy-ccb2a712ab55a34b26b06bfd325020af3889b00c.tar.gz
Merge pull request #11155 from eric-wieser/datetime-stack-overflow
BUG: Prevent stackoverflow in conversion to datetime types
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/include/numpy/npy_3kcompat.h8
-rw-r--r--numpy/core/src/multiarray/datetime.c24
-rw-r--r--numpy/core/src/multiarray/number.c13
-rw-r--r--numpy/core/tests/test_datetime.py17
4 files changed, 41 insertions, 21 deletions
diff --git a/numpy/core/include/numpy/npy_3kcompat.h b/numpy/core/include/numpy/npy_3kcompat.h
index 56fbd99af..2d0ccd3b9 100644
--- a/numpy/core/include/numpy/npy_3kcompat.h
+++ b/numpy/core/include/numpy/npy_3kcompat.h
@@ -61,6 +61,14 @@ static NPY_INLINE int PyInt_Check(PyObject *op) {
PySlice_GetIndicesEx((PySliceObject *)op, nop, start, end, step, slicelength)
#endif
+/* <2.7.11 and <3.4.4 have the wrong argument type for Py_EnterRecursiveCall */
+#if (PY_VERSION_HEX < 0x02070B00) || \
+ ((0x03000000 <= PY_VERSION_HEX) && (PY_VERSION_HEX < 0x03040400))
+ #define Npy_EnterRecursiveCall(x) Py_EnterRecursiveCall((char *)(x))
+#else
+ #define Npy_EnterRecursiveCall(x) Py_EnterRecursiveCall(x)
+#endif
+
/*
* PyString -> PyBytes
*/
diff --git a/numpy/core/src/multiarray/datetime.c b/numpy/core/src/multiarray/datetime.c
index af542aecc..df9b9cec4 100644
--- a/numpy/core/src/multiarray/datetime.c
+++ b/numpy/core/src/multiarray/datetime.c
@@ -3721,19 +3721,21 @@ recursive_find_object_datetime64_type(PyObject *obj,
}
for (i = 0; i < len; ++i) {
+ int ret;
PyObject *f = PySequence_GetItem(obj, i);
if (f == NULL) {
return -1;
}
- if (f == obj) {
- Py_DECREF(f);
- return 0;
- }
- if (recursive_find_object_datetime64_type(f, meta) < 0) {
+ if (Npy_EnterRecursiveCall(" in recursive_find_object_datetime64_type") != 0) {
Py_DECREF(f);
return -1;
}
+ ret = recursive_find_object_datetime64_type(f, meta);
+ Py_LeaveRecursiveCall();
Py_DECREF(f);
+ if (ret < 0) {
+ return ret;
+ }
}
return 0;
@@ -3823,19 +3825,21 @@ recursive_find_object_timedelta64_type(PyObject *obj,
}
for (i = 0; i < len; ++i) {
+ int ret;
PyObject *f = PySequence_GetItem(obj, i);
if (f == NULL) {
return -1;
}
- if (f == obj) {
- Py_DECREF(f);
- return 0;
- }
- if (recursive_find_object_timedelta64_type(f, meta) < 0) {
+ if (Npy_EnterRecursiveCall(" in recursive_find_object_timedelta64_type") != 0) {
Py_DECREF(f);
return -1;
}
+ ret = recursive_find_object_timedelta64_type(f, meta);
+ Py_LeaveRecursiveCall();
Py_DECREF(f);
+ if (ret < 0) {
+ return ret;
+ }
}
return 0;
diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c
index 14389a925..448d2d9c2 100644
--- a/numpy/core/src/multiarray/number.c
+++ b/numpy/core/src/multiarray/number.c
@@ -16,15 +16,6 @@
#include "binop_override.h"
-/* <2.7.11 and <3.4.4 have the wrong argument type for Py_EnterRecursiveCall */
-#if (PY_VERSION_HEX < 0x02070B00) || \
- ((0x03000000 <= PY_VERSION_HEX) && (PY_VERSION_HEX < 0x03040400))
- #define _Py_EnterRecursiveCall(x) Py_EnterRecursiveCall((char *)(x))
-#else
- #define _Py_EnterRecursiveCall(x) Py_EnterRecursiveCall(x)
-#endif
-
-
/*************************************************************************
**************** Implement Number Protocol ****************************
*************************************************************************/
@@ -796,7 +787,7 @@ _array_nonzero(PyArrayObject *mp)
n = PyArray_SIZE(mp);
if (n == 1) {
int res;
- if (_Py_EnterRecursiveCall(" while converting array to bool")) {
+ if (Npy_EnterRecursiveCall(" while converting array to bool")) {
return -1;
}
res = PyArray_DESCR(mp)->f->nonzero(PyArray_DATA(mp), mp);
@@ -850,7 +841,7 @@ array_scalar_forward(PyArrayObject *v,
/* Need to guard against recursion if our array holds references */
if (PyDataType_REFCHK(PyArray_DESCR(v))) {
PyObject *res;
- if (_Py_EnterRecursiveCall(where) != 0) {
+ if (Npy_EnterRecursiveCall(where) != 0) {
Py_DECREF(scalar);
return NULL;
}
diff --git a/numpy/core/tests/test_datetime.py b/numpy/core/tests/test_datetime.py
index e433877e8..942554cae 100644
--- a/numpy/core/tests/test_datetime.py
+++ b/numpy/core/tests/test_datetime.py
@@ -17,6 +17,11 @@ try:
except ImportError:
_has_pytz = False
+try:
+ RecursionError
+except NameError:
+ RecursionError = RuntimeError # python < 3.5
+
class TestDateTime(object):
def test_datetime_dtype_creation(self):
@@ -1997,6 +2002,18 @@ class TestDateTime(object):
continue
assert_raises(TypeError, np.isnat, np.zeros(10, t))
+ def test_corecursive_input(self):
+ # construct a co-recursive list
+ a, b = [], []
+ a.append(b)
+ b.append(a)
+ obj_arr = np.array([None])
+ obj_arr[0] = a
+
+ # gh-11154: This shouldn't cause a C stack overflow
+ assert_raises(RecursionError, obj_arr.astype, 'M8')
+ assert_raises(RecursionError, obj_arr.astype, 'm8')
+
class TestDateTimeData(object):