summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authormattip <matti.picus@gmail.com>2018-03-29 13:42:40 +0300
committermattip <matti.picus@gmail.com>2018-04-21 23:53:44 +0300
commit05d94b9f59f2ca8e9dbc82fd01ac31a6b6aa34d7 (patch)
treee1097b6cfd346abe13aff3f2d0064290467e153d /numpy/core
parente0b5e8740efe6d42c909c1374494e614592c65ab (diff)
downloadnumpy-05d94b9f59f2ca8e9dbc82fd01ac31a6b6aa34d7.tar.gz
BUG: test, fix PyArray_DiscardWritebackIfCopy refcount issue and document
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/include/numpy/ndarrayobject.h9
-rw-r--r--numpy/core/src/multiarray/_multiarray_tests.c.src15
-rw-r--r--numpy/core/tests/test_multiarray.py23
3 files changed, 45 insertions, 2 deletions
diff --git a/numpy/core/include/numpy/ndarrayobject.h b/numpy/core/include/numpy/ndarrayobject.h
index ec0fd1ee9..97e41b6f3 100644
--- a/numpy/core/include/numpy/ndarrayobject.h
+++ b/numpy/core/include/numpy/ndarrayobject.h
@@ -170,14 +170,19 @@ extern "C" CONFUSE_EMACS
(k)*PyArray_STRIDES(obj)[2] + \
(l)*PyArray_STRIDES(obj)[3]))
+/* Move to arrayobject.c once PyArray_XDECREF_ERR is removed */
static NPY_INLINE void
PyArray_DiscardWritebackIfCopy(PyArrayObject *arr)
{
if (arr != NULL) {
+ PyArrayObject_fields *fa = (PyArrayObject_fields *)arr;
if ((PyArray_FLAGS(arr) & NPY_ARRAY_WRITEBACKIFCOPY) ||
(PyArray_FLAGS(arr) & NPY_ARRAY_UPDATEIFCOPY)) {
- PyArrayObject *base = (PyArrayObject *)PyArray_BASE(arr);
- PyArray_ENABLEFLAGS(base, NPY_ARRAY_WRITEABLE);
+ if (fa->base) {
+ PyArray_ENABLEFLAGS((PyArrayObject*)fa->base, NPY_ARRAY_WRITEABLE);
+ Py_DECREF(fa->base);
+ fa->base = NULL;
+ }
PyArray_CLEARFLAGS(arr, NPY_ARRAY_WRITEBACKIFCOPY);
PyArray_CLEARFLAGS(arr, NPY_ARRAY_UPDATEIFCOPY);
}
diff --git a/numpy/core/src/multiarray/_multiarray_tests.c.src b/numpy/core/src/multiarray/_multiarray_tests.c.src
index 38698887a..0299f1a1b 100644
--- a/numpy/core/src/multiarray/_multiarray_tests.c.src
+++ b/numpy/core/src/multiarray/_multiarray_tests.c.src
@@ -687,6 +687,18 @@ npy_resolve(PyObject* NPY_UNUSED(self), PyObject* args)
Py_RETURN_NONE;
}
+/* resolve WRITEBACKIFCOPY */
+static PyObject*
+npy_discard(PyObject* NPY_UNUSED(self), PyObject* args)
+{
+ if (!PyArray_Check(args)) {
+ PyErr_SetString(PyExc_TypeError, "test needs ndarray input");
+ return NULL;
+ }
+ PyArray_DiscardWritebackIfCopy((PyArrayObject*)args);
+ Py_RETURN_NONE;
+}
+
#if !defined(NPY_PY3K)
static PyObject *
int_subclass(PyObject *dummy, PyObject *args)
@@ -1857,6 +1869,9 @@ static PyMethodDef Multiarray_TestsMethods[] = {
{"npy_resolve",
npy_resolve,
METH_O, NULL},
+ {"npy_discard",
+ npy_discard,
+ METH_O, NULL},
#if !defined(NPY_PY3K)
{"test_int_subclass",
int_subclass,
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 806a3b083..16d47839f 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -7255,6 +7255,7 @@ class TestWritebackIfCopy(object):
assert_equal(arr, -100)
# after resolve, the two arrays no longer reference each other
assert_(not arr_wb.ctypes.data == 0)
+ assert_equal(arr_wb.base, None)
arr_wb[:] = 100
assert_equal(arr, -100)
@@ -7266,6 +7267,28 @@ class TestWritebackIfCopy(object):
_multiarray_tests.npy_abuse_writebackifcopy(v)
assert len(sup.log) == 1
+ def test_view_discard_refcount(self):
+ from numpy.core._multiarray_tests import npy_create_writebackifcopy, npy_discard
+ arr = np.arange(9).reshape(3, 3).T
+ orig = arr.copy()
+ if HAS_REFCOUNT:
+ arr_cnt = sys.getrefcount(arr)
+ arr_wb = npy_create_writebackifcopy(arr)
+ assert_(arr_wb.flags.writebackifcopy)
+ assert_(arr_wb.base is arr)
+ arr_wb[:] = -100
+ npy_discard(arr_wb)
+ # arr remains unchanged after discard
+ assert_equal(arr, orig)
+ # after discard, the two arrays no longer reference each other
+ assert_(not arr_wb.ctypes.data == 0)
+ assert_equal(arr_wb.base, None)
+ if HAS_REFCOUNT:
+ assert_equal(arr_cnt, sys.getrefcount(arr))
+ arr_wb[:] = 100
+ assert_equal(arr, orig)
+
+
class TestArange(object):
def test_infinite(self):
assert_raises_regex(