summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorTravis E. Oliphant <teoliphant@gmail.com>2012-07-17 19:37:09 -0700
committerTravis E. Oliphant <teoliphant@gmail.com>2012-07-17 19:37:09 -0700
commitbc1005324566269d016ad9c17a25b43c6b9fc1de (patch)
tree3ba747f68dcc2c0403c020fd67d84dabad3cdb8c /numpy
parent578a4199a81e7464011661fcf8d46a8af2235db2 (diff)
parenta03e8b4d286e91ef5823c059dcfb7a52ce420725 (diff)
downloadnumpy-bc1005324566269d016ad9c17a25b43c6b9fc1de.tar.gz
Merge pull request #350 from jayvius/get-view2
Add transition code for returning view when selecting subset of fields
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_internal.py15
-rw-r--r--numpy/core/src/multiarray/arrayobject.c8
-rw-r--r--numpy/core/src/multiarray/mapping.c4
-rw-r--r--numpy/core/tests/test_multiarray.py60
4 files changed, 69 insertions, 18 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 309b53c44..92ab0c8b0 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -287,18 +287,15 @@ def _newnames(datatype, order):
def _index_fields(ary, fields):
from multiarray import empty, dtype
dt = ary.dtype
- new_dtype = [(name, dt[name]) for name in fields if name in dt.names]
- if ary.flags.f_contiguous:
- order = 'F'
- else:
- order = 'C'
- newarray = empty(ary.shape, dtype=new_dtype, order=order)
+ names = [name for name in fields if name in dt.names]
+ formats = [dt.fields[name][0] for name in fields if name in dt.names]
+ offsets = [dt.fields[name][1] for name in fields if name in dt.names]
- for name in fields:
- newarray[name] = ary[name]
+ view_dtype = {'names':names, 'formats':formats, 'offsets':offsets, 'itemsize':dt.itemsize}
+ view = ary.view(dtype=view_dtype)
- return newarray
+ return view.copy()
# Given a string containing a PEP 3118 format specifier,
# construct a Numpy dtype
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index e8bc6b7b6..f0e9e36a5 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -694,9 +694,11 @@ array_might_be_written(PyArrayObject *obj)
{
const char *msg =
"Numpy has detected that you (may be) writing to an array returned\n"
- "by numpy.diagonal. This code will likely break in the next numpy\n"
- "release -- see numpy.diagonal docs for details. The quick fix is\n"
- "to make an explicit copy (e.g., do arr.diagonal().copy()).";
+ "by numpy.diagonal or by selecting multiple fields in a record\n"
+ "array. This code will likely break in the next numpy release --\n"
+ "see numpy.diagonal or arrays.indexing reference docs for details.\n"
+ "The quick fix is to make an explicit copy (e.g., do\n"
+ "arr.diagonal().copy() or arr[['f0','f1']].copy()).";
if (PyArray_FLAGS(obj) & NPY_ARRAY_WARN_ON_WRITE) {
if (DEPRECATE_FUTUREWARNING(msg) < 0) {
return -1;
diff --git a/numpy/core/src/multiarray/mapping.c b/numpy/core/src/multiarray/mapping.c
index cdefb9982..663a3ef7f 100644
--- a/numpy/core/src/multiarray/mapping.c
+++ b/numpy/core/src/multiarray/mapping.c
@@ -976,6 +976,10 @@ array_subscript(PyArrayObject *self, PyObject *op)
obj = PyObject_CallMethod(_numpy_internal,
"_index_fields", "OO", self, op);
Py_DECREF(_numpy_internal);
+ if (obj == NULL) {
+ return NULL;
+ }
+ PyArray_ENABLEFLAGS((PyArrayObject*)obj, NPY_ARRAY_WARN_ON_WRITE);
return obj;
}
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 3427800a2..e3e24fae1 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1909,7 +1909,8 @@ class TestRecord(TestCase):
def test_field_names(self):
# Test unicode and 8-bit / byte strings can be used
a = np.zeros((1,), dtype=[('f1', 'i4'),
- ('f2', [('sf1', 'i4')])])
+ ('f2', 'i4'),
+ ('f3', [('sf1', 'i4')])])
is_py3 = sys.version_info[0] >= 3
if is_py3:
funcs = (str,)
@@ -1934,12 +1935,18 @@ class TestRecord(TestCase):
assert_raises(IndexError, b[0].__setitem__, fnn, 1)
assert_raises(IndexError, b[0].__getitem__, fnn)
# Subfield
- fn2 = func('f2')
+ fn3 = func('f3')
sfn1 = func('sf1')
- b[fn2][sfn1] = 1
- assert_equal(b[fn2][sfn1], 1)
- assert_raises(ValueError, b[fn2].__setitem__, fnn, 1)
- assert_raises(ValueError, b[fn2].__getitem__, fnn)
+ b[fn3][sfn1] = 1
+ assert_equal(b[fn3][sfn1], 1)
+ assert_raises(ValueError, b[fn3].__setitem__, fnn, 1)
+ assert_raises(ValueError, b[fn3].__getitem__, fnn)
+ # multiple Subfields
+ fn2 = func('f2')
+ b[fn2] = 3
+ assert_equal(b[['f1','f2']][0].tolist(), (2, 3))
+ assert_equal(b[['f2','f1']][0].tolist(), (3, 2))
+ assert_equal(b[['f1','f3']][0].tolist(), (2, (1,)))
# non-ascii unicode field indexing is well behaved
if not is_py3:
raise SkipTest('non ascii unicode field indexing skipped; '
@@ -1948,6 +1955,47 @@ class TestRecord(TestCase):
assert_raises(ValueError, a.__setitem__, u'\u03e0', 1)
assert_raises(ValueError, a.__getitem__, u'\u03e0')
+ def test_field_names_deprecation(self):
+ import warnings
+ from numpy.testing.utils import WarningManager
+ def collect_warning_types(f, *args, **kwargs):
+ ctx = WarningManager(record=True)
+ warning_log = ctx.__enter__()
+ warnings.simplefilter("always")
+ try:
+ f(*args, **kwargs)
+ finally:
+ ctx.__exit__()
+ return [w.category for w in warning_log]
+ a = np.zeros((1,), dtype=[('f1', 'i4'),
+ ('f2', 'i4'),
+ ('f3', [('sf1', 'i4')])])
+ a['f1'][0] = 1
+ a['f2'][0] = 2
+ a['f3'][0] = (3,)
+ b = np.zeros((1,), dtype=[('f1', 'i4'),
+ ('f2', 'i4'),
+ ('f3', [('sf1', 'i4')])])
+ b['f1'][0] = 1
+ b['f2'][0] = 2
+ b['f3'][0] = (3,)
+
+ # All the different functions raise a warning, but not an error, and
+ # 'a' is not modified:
+ assert_equal(collect_warning_types(a[['f1','f2']].__setitem__, 0, (10,20)),
+ [FutureWarning])
+ assert_equal(a, b)
+ # Views also warn
+ subset = a[['f1','f2']]
+ subset_view = subset.view()
+ assert_equal(collect_warning_types(subset_view['f1'].__setitem__, 0, 10),
+ [FutureWarning])
+ # But the write goes through:
+ assert_equal(subset['f1'][0], 10)
+ # Only one warning per multiple field indexing, though (even if there are
+ # multiple views involved):
+ assert_equal(collect_warning_types(subset['f1'].__setitem__, 0, 10),
+ [])
class TestView(TestCase):
def test_basic(self):