diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/shape.c | 17 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 6 |
2 files changed, 18 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/shape.c b/numpy/core/src/multiarray/shape.c index 3ac71e285..30820737e 100644 --- a/numpy/core/src/multiarray/shape.c +++ b/numpy/core/src/multiarray/shape.c @@ -89,11 +89,19 @@ PyArray_Resize(PyArrayObject *self, PyArray_Dims *newshape, int refcheck, return NULL; } + if (PyArray_BASE(self) != NULL + || (((PyArrayObject_fields *)self)->weakreflist != NULL)) { + PyErr_SetString(PyExc_ValueError, + "cannot resize an array that " + "references or is referenced\n" + "by another array in this way. Use the np.resize function."); + return NULL; + } if (refcheck) { #ifdef PYPY_VERSION PyErr_SetString(PyExc_ValueError, "cannot resize an array with refcheck=True on PyPy.\n" - "Use the resize function or refcheck=False"); + "Use the np.resize function or refcheck=False"); return NULL; #else refcnt = PyArray_REFCOUNT(self); @@ -102,13 +110,12 @@ PyArray_Resize(PyArrayObject *self, PyArray_Dims *newshape, int refcheck, else { refcnt = 1; } - if ((refcnt > 2) - || (PyArray_BASE(self) != NULL) - || (((PyArrayObject_fields *)self)->weakreflist != NULL)) { + if (refcnt > 2) { PyErr_SetString(PyExc_ValueError, "cannot resize an array that " "references or is referenced\n" - "by another array in this way. Use the resize function"); + "by another array in this way.\n" + "Use the np.resize function or refcheck=False"); return NULL; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 8cd0f4d92..4b2a38990 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -4829,6 +4829,12 @@ class TestResize(object): x_view.resize((0, 10)) x_view.resize((0, 100)) + def test_check_weakref(self): + x = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + xref = weakref.ref(x) + assert_raises(ValueError, x.resize, (5, 1)) + del xref # avoid pyflakes unused variable warning. + class TestRecord(object): def test_field_rename(self): |