summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathaniel J. Smith <njs@pobox.com>2012-05-15 22:12:41 +0100
committerNathaniel J. Smith <njs@pobox.com>2012-05-16 14:11:28 +0100
commitbea52bf307782b2a211b7fcfa6696fad45dae275 (patch)
tree435834d34ff9dd7a12f5acc62b45cbf77d623453
parentd403fed2423caec4149937fe48781ac68b21fddb (diff)
downloadnumpy-bea52bf307782b2a211b7fcfa6696fad45dae275.tar.gz
Transition scheme for allowing PyArray_Diagonal to return a view
PyArray_Diagonal is changed to return a copy of the diagonal (as in numpy 1.6 and earlier), but with a new (hidden) WARN_ON_WRITE flag set. Writes to this array (or views thereof) will continue to work as normal, but the first write will trigger a DeprecationWarning. We also issue this warning if someone extracts a non-numpy writeable view of the array (e.g., by accessing the Python-level .data attribute). There are likely still places where the data buffer is exposed that I've missed -- review welcome! New known-fail test: eye() for maskna arrays was only implemented by exploiting ndarray.diagonal's view-ness, so it is now unimplemented again, and the corresponding test is marked known-fail.
-rw-r--r--numpy/core/include/numpy/ndarraytypes.h11
-rw-r--r--numpy/core/numeric.py5
-rw-r--r--numpy/core/src/multiarray/arrayobject.c45
-rw-r--r--numpy/core/src/multiarray/arrayobject.h3
-rw-r--r--numpy/core/src/multiarray/arraytypes.c.src4
-rw-r--r--numpy/core/src/multiarray/buffer.c4
-rw-r--r--numpy/core/src/multiarray/getset.c13
-rw-r--r--numpy/core/src/multiarray/item_selection.c15
-rw-r--r--numpy/core/tests/test_maskna.py6
-rw-r--r--numpy/core/tests/test_multiarray.py101
-rw-r--r--numpy/lib/twodim_base.py8
11 files changed, 205 insertions, 10 deletions
diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h
index db5257761..2d591499b 100644
--- a/numpy/core/include/numpy/ndarraytypes.h
+++ b/numpy/core/include/numpy/ndarraytypes.h
@@ -917,6 +917,15 @@ typedef int (PyArray_FinalizeFunc)(PyArrayObject *, PyObject *);
*/
#define NPY_ARRAY_ALLOWNA 0x8000
+/*
+ * This flag is used internally to mark arrays which we would like to, in the
+ * future, turn into views. It causes a warning to be issued on the first
+ * attempt to write to the array (but the write is allowed to succeed).
+ *
+ * Currently it is set on arrays returned by ndarray.diagonal.
+ */
+#define NPY_ARRAY_WARN_ON_WRITE 0x10000
+
#define NPY_ARRAY_BEHAVED (NPY_ARRAY_ALIGNED | \
NPY_ARRAY_WRITEABLE)
@@ -1550,7 +1559,7 @@ static NPY_INLINE PyObject *
PyArray_GETITEM(const PyArrayObject *arr, const char *itemptr)
{
return ((PyArrayObject_fields *)arr)->descr->f->getitem(
- (void *)itemptr, (PyArrayObject *)arr);
+ (void *)itemptr, (PyArrayObject *)arr);
}
static NPY_INLINE int
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 9a51229a1..aa7d2c29b 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1956,9 +1956,8 @@ def identity(n, dtype=None, maskna=False):
[ 0., 0., 1.]])
"""
- a = zeros((n,n), dtype=dtype, maskna=maskna)
- a.diagonal()[...] = 1
- return a
+ from numpy import eye
+ return eye(n, dtype=dtype, maskna=maskna)
def allclose(a, b, rtol=1.e-5, atol=1.e-8):
"""
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index 460db4b73..ff755f536 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -91,6 +91,13 @@ PyArray_SetUpdateIfCopyBase(PyArrayObject *arr, PyArrayObject *base)
goto fail;
}
+ /* Any writes to 'arr' will magicaly turn into writes to 'base', so we
+ * should warn if necessary.
+ */
+ if (PyArray_FLAGS(base) & NPY_ARRAY_WARN_ON_WRITE) {
+ PyArray_ENABLEFLAGS(arr, NPY_ARRAY_WARN_ON_WRITE);
+ }
+
/* Unlike PyArray_SetBaseObject, we do not compress the chain of base
references.
*/
@@ -143,6 +150,11 @@ PyArray_SetBaseObject(PyArrayObject *arr, PyObject *obj)
PyArrayObject *obj_arr = (PyArrayObject *)obj;
PyObject *tmp;
+ /* Propagate WARN_ON_WRITE through views. */
+ if (PyArray_FLAGS(obj_arr) & NPY_ARRAY_WARN_ON_WRITE) {
+ PyArray_ENABLEFLAGS(arr, NPY_ARRAY_WARN_ON_WRITE);
+ }
+
/* If this array owns its own data, stop collapsing */
if (PyArray_CHKFLAGS(obj_arr, NPY_ARRAY_OWNDATA)) {
break;
@@ -743,6 +755,36 @@ PyArray_CompareString(char *s1, char *s2, size_t len)
}
+/* Call this from contexts where an array might be written to, but we have no
+ * way to tell. (E.g., when converting to a read-write buffer.)
+ */
+NPY_NO_EXPORT int
+array_might_be_written(PyArrayObject *obj)
+{
+ const char *msg = "Traditionally, numpy.diagonal, numpy.diag, and "
+ "ndarray.diagonal have returned copies of an array's diagonal. "
+ "In a future version of numpy, they will return a view onto the "
+ "existing array (like slicing does). Numpy has detected that this "
+ "code (might be) writing to an array returned by one of these "
+ "functions, which in a future release will modify your original "
+ "array. To avoid this warning, make an explicit copy, e.g. by "
+ "replacing arr.diagonal() with arr.diagonal().copy().";
+ if (PyArray_FLAGS(obj) & NPY_ARRAY_WARN_ON_WRITE) {
+ if (PyErr_WarnEx(PyExc_DeprecationWarning, msg, 1) < 0) {
+ return -1;
+ }
+ /* Only warn once per array */
+ while (1) {
+ PyArray_CLEARFLAGS(obj, NPY_ARRAY_WARN_ON_WRITE);
+ if (!PyArray_BASE(obj) || !PyArray_Check(PyArray_BASE(obj))) {
+ break;
+ }
+ obj = (PyArrayObject *)PyArray_BASE(obj);
+ }
+ }
+ return 0;
+}
+
/*NUMPY_API
*
* This function does nothing if obj is writeable, and raises an exception
@@ -761,6 +803,9 @@ PyArray_RequireWriteable(PyArrayObject *obj, const char * err)
PyErr_SetString(PyExc_ValueError, err);
return -1;
}
+ if (array_might_be_written(obj) < 0) {
+ return -1;
+ }
return 0;
}
diff --git a/numpy/core/src/multiarray/arrayobject.h b/numpy/core/src/multiarray/arrayobject.h
index ec3361435..67f38d065 100644
--- a/numpy/core/src/multiarray/arrayobject.h
+++ b/numpy/core/src/multiarray/arrayobject.h
@@ -12,4 +12,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
NPY_NO_EXPORT PyObject *
array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op);
+NPY_NO_EXPORT int
+array_might_be_written(PyArrayObject *obj);
+
#endif
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index 6dcc7d8e9..3ad8d1a75 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -18,6 +18,7 @@
#include "usertypes.h"
#include "_datetime.h"
#include "na_object.h"
+#include "arrayobject.h"
#include "numpyos.h"
@@ -649,6 +650,9 @@ VOID_getitem(char *ip, PyArrayObject *ap)
* current item a view of it
*/
if (PyArray_ISWRITEABLE(ap)) {
+ if (array_might_be_written(ap) < 0) {
+ return NULL;
+ }
u = (PyArrayObject *)PyBuffer_FromReadWriteMemory(ip, itemsize);
}
else {
diff --git a/numpy/core/src/multiarray/buffer.c b/numpy/core/src/multiarray/buffer.c
index 8389c70b0..3f58451c8 100644
--- a/numpy/core/src/multiarray/buffer.c
+++ b/numpy/core/src/multiarray/buffer.c
@@ -13,6 +13,7 @@
#include "buffer.h"
#include "numpyos.h"
+#include "arrayobject.h"
/*************************************************************************
**************** Implement Buffer Protocol ****************************
@@ -630,6 +631,9 @@ array_getbuffer(PyObject *obj, Py_buffer *view, int flags)
if (PyArray_RequireWriteable(self, NULL) < 0) {
goto fail;
}
+ if (array_might_be_written(self) < 0) {
+ goto fail;
+ }
}
if (view == NULL) {
diff --git a/numpy/core/src/multiarray/getset.c b/numpy/core/src/multiarray/getset.c
index 1ebdb53b2..208a8aa20 100644
--- a/numpy/core/src/multiarray/getset.c
+++ b/numpy/core/src/multiarray/getset.c
@@ -16,6 +16,7 @@
#include "scalartypes.h"
#include "descriptor.h"
#include "getset.h"
+#include "arrayobject.h"
/******************* array attribute get and set routines ******************/
@@ -260,6 +261,10 @@ array_interface_get(PyArrayObject *self)
return NULL;
}
+ if (array_might_be_written(self) < 0) {
+ return NULL;
+ }
+
/* dataptr */
obj = array_dataptr_get(self);
PyDict_SetItemString(dict, "data", obj);
@@ -302,6 +307,9 @@ array_data_get(PyArrayObject *self)
}
nbytes = PyArray_NBYTES(self);
if (PyArray_ISWRITEABLE(self)) {
+ if (array_might_be_written(self) < 0) {
+ return NULL;
+ }
return PyBuffer_FromReadWriteObject((PyObject *)self, 0, (Py_ssize_t) nbytes);
}
else {
@@ -557,6 +565,11 @@ array_struct_get(PyArrayObject *self)
PyArrayInterface *inter;
PyObject *ret;
+ if (PyArray_ISWRITEABLE(self)) {
+ if (array_might_be_written(self) < 0) {
+ return NULL;
+ }
+ }
inter = (PyArrayInterface *)PyArray_malloc(sizeof(PyArrayInterface));
if (inter==NULL) {
return PyErr_NoMemory();
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 21892045f..92cd8ac29 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -1841,7 +1841,11 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
/*NUMPY_API
* Diagonal
*
- * As of NumPy 1.7, this function always returns a view into 'self'.
+ * In NumPy versions prior to 1.7, this function always returned a copy of
+ * the diagonal array. In 1.7, the code has been updated to compute a view
+ * onto 'self', but it still copies this array before returning, as well as
+ * setting the internal WARN_ON_WRITE flag. In a future version, it will
+ * simply return a view onto self.
*/
NPY_NO_EXPORT PyObject *
PyArray_Diagonal(PyArrayObject *self, int offset, int axis1, int axis2)
@@ -1857,6 +1861,7 @@ PyArray_Diagonal(PyArrayObject *self, int offset, int axis1, int axis2)
PyArrayObject *ret;
PyArray_Descr *dtype;
npy_intp ret_shape[NPY_MAXDIMS], ret_strides[NPY_MAXDIMS];
+ PyObject *copy;
if (ndim < 2) {
PyErr_SetString(PyExc_ValueError,
@@ -1987,7 +1992,13 @@ PyArray_Diagonal(PyArrayObject *self, int offset, int axis1, int axis2)
fret->flags |= NPY_ARRAY_MASKNA;
}
- return (PyObject *)ret;
+ /* For backwards compatibility, during the deprecation period: */
+ copy = PyArray_NewCopy(ret, NPY_KEEPORDER);
+ if (!copy) {
+ return NULL;
+ }
+ PyArray_ENABLEFLAGS((PyArrayObject *)copy, NPY_ARRAY_WARN_ON_WRITE);
+ return copy;
}
/*NUMPY_API
diff --git a/numpy/core/tests/test_maskna.py b/numpy/core/tests/test_maskna.py
index 5aeb1d668..0b48ab803 100644
--- a/numpy/core/tests/test_maskna.py
+++ b/numpy/core/tests/test_maskna.py
@@ -1408,11 +1408,9 @@ def test_array_maskna_diagonal():
a.shape = (2,3)
a[0,1] = np.NA
- # Should produce a view into a
res = a.diagonal()
- assert_(res.base is a)
assert_(res.flags.maskna)
- assert_(not res.flags.ownmaskna)
+ assert_(res.flags.ownmaskna)
assert_equal(res, [0, 4])
res = a.diagonal(-1)
@@ -1584,6 +1582,8 @@ def test_array_maskna_linspace_logspace():
assert_(b.flags.maskna)
+from numpy.testing import dec
+@dec.knownfailureif(True, "eye is not implemented for maskna")
def test_array_maskna_eye_identity():
# np.eye
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index bfbc69d96..c315ea385 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -811,6 +811,107 @@ class TestMethods(TestCase):
# Order of axis argument doesn't matter:
assert_equal(b.diagonal(0, 2, 1), [[0, 3], [4, 7]])
+ def test_diagonal_deprecation(self):
+ import warnings
+ from numpy.testing.utils import WarningManager
+ def collect_warning_types(f, *args, **kwargs):
+ ctx = WarningManager(record=True)
+ warning_log = ctx.__enter__()
+ warnings.simplefilter("always")
+ try:
+ f(*args, **kwargs)
+ finally:
+ ctx.__exit__()
+ return [w.category for w in warning_log]
+ a = np.arange(9).reshape(3, 3)
+ # All the different functions raise a warning, but not an error, and
+ # 'a' is not modified:
+ assert_equal(collect_warning_types(a.diagonal().__setitem__, 0, 10),
+ [DeprecationWarning])
+ assert_equal(a, np.arange(9).reshape(3, 3))
+ assert_equal(collect_warning_types(np.diagonal(a).__setitem__, 0, 10),
+ [DeprecationWarning])
+ assert_equal(a, np.arange(9).reshape(3, 3))
+ assert_equal(collect_warning_types(np.diag(a).__setitem__, 0, 10),
+ [DeprecationWarning])
+ assert_equal(a, np.arange(9).reshape(3, 3))
+ # Views also warn
+ d = np.diagonal(a)
+ d_view = d.view()
+ assert_equal(collect_warning_types(d_view.__setitem__, 0, 10),
+ [DeprecationWarning])
+ # But the write goes through:
+ assert_equal(d[0], 10)
+ # Only one warning per call to diagonal, though (even if there are
+ # multiple views involved):
+ assert_equal(collect_warning_types(d.__setitem__, 0, 10),
+ [])
+
+ # Other ways of accessing the data also warn:
+ # .data gives a read-write buffer:
+ assert_equal(collect_warning_types(getattr, a.diagonal(), "data"),
+ [DeprecationWarning])
+ # Void dtypes can give us a read-write buffer, but only in Python 2:
+ import sys
+ if sys.version_info[0] < 3:
+ aV = np.empty((3, 3), dtype="V10")
+ assert_equal(collect_warning_types(aV.diagonal().item, 0),
+ [DeprecationWarning])
+ # XX it seems that direct indexing of a void object returns a void
+ # scalar, which ignores not just WARN_ON_WRITE but even WRITEABLE.
+ # i.e. in this:
+ # a = np.empty(10, dtype="V10")
+ # a.flags.writeable = False
+ # buf = a[0].item()
+ # 'buf' ends up as a writeable buffer. I guess no-one actually
+ # uses void types like this though...
+ # __array_interface also lets a data pointer get away from us
+ log = collect_warning_types(getattr, a.diagonal(),
+ "__array_interface__")
+ assert_equal(log, [DeprecationWarning])
+ # ctypeslib goes via __array_interface__:
+ log = collect_warning_types(np.ctypeslib.as_ctypes, a.diagonal())
+ assert_equal(log, [DeprecationWarning])
+ # __array_struct__
+ log = collect_warning_types(getattr, a.diagonal(), "__array_struct__")
+ assert_equal(log, [DeprecationWarning])
+ # PEP 3118:
+ if hasattr(__builtins__, "memoryview"):
+ assert_equal(collect_warning_types(memoryview, a.diagonal()),
+ [DeprecationWarning])
+
+ # Make sure that our recommendation to silence the warning by copying
+ # the array actually works:
+ diag_copy = a.diagonal().copy()
+ assert_equal(collect_warning_types(diag_copy.__setitem__, 0, 10),
+ [])
+ # There might be people who get a spurious warning because they are
+ # extracting a buffer, but then use that buffer in a read-only
+ # fashion. And they might get cranky at having to create a superfluous
+ # copy just to work around this spurious warning. A reasonable
+ # solution would be for them to mark their usage as read-only, and
+ # thus safe for both past and future PyArray_Diagonal
+ # semantics. So let's make sure that setting the diagonal array to
+ # non-writeable will suppress these warnings:
+ ro_diag = a.diagonal()
+ ro_diag.flags.writeable = False
+ assert_equal(collect_warning_types(getattr, ro_diag, "data"), [])
+ # __array_interface__ has no way to communicate read-onlyness --
+ # effectively all __array_interface__ arrays are assumed to be
+ # writeable :-(
+ # ro_diag = a.diagonal()
+ # ro_diag.flags.writeable = False
+ # assert_equal(collect_warning_types(getattr, ro_diag,
+ # "__array_interface__"), [])
+ ro_diag = a.diagonal()
+ ro_diag.flags.writeable = False
+ assert_equal(collect_warning_types(memoryview, ro_diag), [])
+ ro_diag = a.diagonal()
+ ro_diag.flags.writeable = False
+ assert_equal(collect_warning_types(getattr, ro_diag,
+ "__array_struct__"), [])
+
+
def test_ravel(self):
a = np.array([[0,1],[2,3]])
assert_equal(a.ravel(), [0,1,2,3])
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index 58d8250a1..2b518aeae 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -210,7 +210,13 @@ def eye(N, M=None, k=0, dtype=float, maskna=False):
if M is None:
M = N
m = zeros((N, M), dtype=dtype, maskna=maskna)
- diagonal(m, k)[...] = 1
+ if k >= M:
+ return m
+ if k >= 0:
+ i = k
+ else:
+ i = (-k) * M
+ m[:M-k].flat[i::M+1] = 1
return m
def diag(v, k=0):