diff options
| author | mattip <matti.picus@gmail.com> | 2018-03-29 13:42:40 +0300 |
|---|---|---|
| committer | mattip <matti.picus@gmail.com> | 2018-04-21 23:53:44 +0300 |
| commit | 05d94b9f59f2ca8e9dbc82fd01ac31a6b6aa34d7 (patch) | |
| tree | e1097b6cfd346abe13aff3f2d0064290467e153d /numpy/core | |
| parent | e0b5e8740efe6d42c909c1374494e614592c65ab (diff) | |
| download | numpy-05d94b9f59f2ca8e9dbc82fd01ac31a6b6aa34d7.tar.gz | |
BUG: test, fix PyArray_DiscardWritebackIfCopy refcount issue and document
Diffstat (limited to 'numpy/core')
| -rw-r--r-- | numpy/core/include/numpy/ndarrayobject.h | 9 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/_multiarray_tests.c.src | 15 | ||||
| -rw-r--r-- | numpy/core/tests/test_multiarray.py | 23 |
3 files changed, 45 insertions, 2 deletions
diff --git a/numpy/core/include/numpy/ndarrayobject.h b/numpy/core/include/numpy/ndarrayobject.h index ec0fd1ee9..97e41b6f3 100644 --- a/numpy/core/include/numpy/ndarrayobject.h +++ b/numpy/core/include/numpy/ndarrayobject.h @@ -170,14 +170,19 @@ extern "C" CONFUSE_EMACS (k)*PyArray_STRIDES(obj)[2] + \ (l)*PyArray_STRIDES(obj)[3])) +/* Move to arrayobject.c once PyArray_XDECREF_ERR is removed */ static NPY_INLINE void PyArray_DiscardWritebackIfCopy(PyArrayObject *arr) { if (arr != NULL) { + PyArrayObject_fields *fa = (PyArrayObject_fields *)arr; if ((PyArray_FLAGS(arr) & NPY_ARRAY_WRITEBACKIFCOPY) || (PyArray_FLAGS(arr) & NPY_ARRAY_UPDATEIFCOPY)) { - PyArrayObject *base = (PyArrayObject *)PyArray_BASE(arr); - PyArray_ENABLEFLAGS(base, NPY_ARRAY_WRITEABLE); + if (fa->base) { + PyArray_ENABLEFLAGS((PyArrayObject*)fa->base, NPY_ARRAY_WRITEABLE); + Py_DECREF(fa->base); + fa->base = NULL; + } PyArray_CLEARFLAGS(arr, NPY_ARRAY_WRITEBACKIFCOPY); PyArray_CLEARFLAGS(arr, NPY_ARRAY_UPDATEIFCOPY); } diff --git a/numpy/core/src/multiarray/_multiarray_tests.c.src b/numpy/core/src/multiarray/_multiarray_tests.c.src index 38698887a..0299f1a1b 100644 --- a/numpy/core/src/multiarray/_multiarray_tests.c.src +++ b/numpy/core/src/multiarray/_multiarray_tests.c.src @@ -687,6 +687,18 @@ npy_resolve(PyObject* NPY_UNUSED(self), PyObject* args) Py_RETURN_NONE; } +/* resolve WRITEBACKIFCOPY */ +static PyObject* +npy_discard(PyObject* NPY_UNUSED(self), PyObject* args) +{ + if (!PyArray_Check(args)) { + PyErr_SetString(PyExc_TypeError, "test needs ndarray input"); + return NULL; + } + PyArray_DiscardWritebackIfCopy((PyArrayObject*)args); + Py_RETURN_NONE; +} + #if !defined(NPY_PY3K) static PyObject * int_subclass(PyObject *dummy, PyObject *args) @@ -1857,6 +1869,9 @@ static PyMethodDef Multiarray_TestsMethods[] = { {"npy_resolve", npy_resolve, METH_O, NULL}, + {"npy_discard", + npy_discard, + METH_O, NULL}, #if !defined(NPY_PY3K) {"test_int_subclass", int_subclass, diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 806a3b083..16d47839f 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -7255,6 +7255,7 @@ class TestWritebackIfCopy(object): assert_equal(arr, -100) # after resolve, the two arrays no longer reference each other assert_(not arr_wb.ctypes.data == 0) + assert_equal(arr_wb.base, None) arr_wb[:] = 100 assert_equal(arr, -100) @@ -7266,6 +7267,28 @@ class TestWritebackIfCopy(object): _multiarray_tests.npy_abuse_writebackifcopy(v) assert len(sup.log) == 1 + def test_view_discard_refcount(self): + from numpy.core._multiarray_tests import npy_create_writebackifcopy, npy_discard + arr = np.arange(9).reshape(3, 3).T + orig = arr.copy() + if HAS_REFCOUNT: + arr_cnt = sys.getrefcount(arr) + arr_wb = npy_create_writebackifcopy(arr) + assert_(arr_wb.flags.writebackifcopy) + assert_(arr_wb.base is arr) + arr_wb[:] = -100 + npy_discard(arr_wb) + # arr remains unchanged after discard + assert_equal(arr, orig) + # after discard, the two arrays no longer reference each other + assert_(not arr_wb.ctypes.data == 0) + assert_equal(arr_wb.base, None) + if HAS_REFCOUNT: + assert_equal(arr_cnt, sys.getrefcount(arr)) + arr_wb[:] = 100 + assert_equal(arr, orig) + + class TestArange(object): def test_infinite(self): assert_raises_regex( |
