diff options
author | Travis E. Oliphant <teoliphant@gmail.com> | 2012-07-17 19:37:09 -0700 |
---|---|---|
committer | Travis E. Oliphant <teoliphant@gmail.com> | 2012-07-17 19:37:09 -0700 |
commit | bc1005324566269d016ad9c17a25b43c6b9fc1de (patch) | |
tree | 3ba747f68dcc2c0403c020fd67d84dabad3cdb8c | |
parent | 578a4199a81e7464011661fcf8d46a8af2235db2 (diff) | |
parent | a03e8b4d286e91ef5823c059dcfb7a52ce420725 (diff) | |
download | numpy-bc1005324566269d016ad9c17a25b43c6b9fc1de.tar.gz |
Merge pull request #350 from jayvius/get-view2
Add transition code for returning view when selecting subset of fields
-rw-r--r-- | doc/release/1.7.0-notes.rst | 7 | ||||
-rw-r--r-- | doc/source/reference/arrays.indexing.rst | 10 | ||||
-rw-r--r-- | numpy/core/_internal.py | 15 | ||||
-rw-r--r-- | numpy/core/src/multiarray/arrayobject.c | 8 | ||||
-rw-r--r-- | numpy/core/src/multiarray/mapping.c | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 60 |
6 files changed, 86 insertions, 18 deletions
diff --git a/doc/release/1.7.0-notes.rst b/doc/release/1.7.0-notes.rst index e8b1de72d..f8f54219c 100644 --- a/doc/release/1.7.0-notes.rst +++ b/doc/release/1.7.0-notes.rst @@ -26,6 +26,13 @@ functions. To facilitate this transition, numpy 1.7 produces a FutureWarning if it detects that you may be attempting to write to such an array. See the documentation for np.diagonal for details. +Similar to np.diagonal above, in a future version of numpy, indexing +a record array by a list of field names will return a view onto the +original array, instead of producing a copy as they do now. As with +np.diagonal, numpy 1.7 produces a FutureWarning if it detects +that you may be attemping to write to such an array. See the documentation +for array indexing for details. + The default casting rule for UFunc out= parameters has been changed from 'unsafe' to 'same_kind'. Most usages which violate the 'same_kind' rule are likely bugs, so this change may expose previously undetected diff --git a/doc/source/reference/arrays.indexing.rst b/doc/source/reference/arrays.indexing.rst index 8da4ecca7..f8966f5c1 100644 --- a/doc/source/reference/arrays.indexing.rst +++ b/doc/source/reference/arrays.indexing.rst @@ -335,6 +335,16 @@ sub-array) but of data type ``x.dtype['field-name']`` and contains only the part of the data in the specified field. Also record array scalars can be "indexed" this way. +Indexing into a record array can also be done with a list of field names, +*e.g.* ``x[['field-name1','field-name2']]``. Currently this returns a new +array containing a copy of the values in the fields specified in the list. +As of NumPy 1.7, returning a copy is being deprecated in favor of returning +a view. A copy will continue to be returned for now, but a FutureWarning +will be issued when writing to the copy. If you depend on the current +behavior, then we suggest copying the returned array explicitly, i.e. use +x[['field-name1','field-name2']].copy(). This will work with both past and +future versions of NumPy. + If the accessed field is a sub-array, the dimensions of the sub-array are appended to the shape of the result. diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 309b53c44..92ab0c8b0 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -287,18 +287,15 @@ def _newnames(datatype, order): def _index_fields(ary, fields): from multiarray import empty, dtype dt = ary.dtype - new_dtype = [(name, dt[name]) for name in fields if name in dt.names] - if ary.flags.f_contiguous: - order = 'F' - else: - order = 'C' - newarray = empty(ary.shape, dtype=new_dtype, order=order) + names = [name for name in fields if name in dt.names] + formats = [dt.fields[name][0] for name in fields if name in dt.names] + offsets = [dt.fields[name][1] for name in fields if name in dt.names] - for name in fields: - newarray[name] = ary[name] + view_dtype = {'names':names, 'formats':formats, 'offsets':offsets, 'itemsize':dt.itemsize} + view = ary.view(dtype=view_dtype) - return newarray + return view.copy() # Given a string containing a PEP 3118 format specifier, # construct a Numpy dtype diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c index e8bc6b7b6..f0e9e36a5 100644 --- a/numpy/core/src/multiarray/arrayobject.c +++ b/numpy/core/src/multiarray/arrayobject.c @@ -694,9 +694,11 @@ array_might_be_written(PyArrayObject *obj) { const char *msg = "Numpy has detected that you (may be) writing to an array returned\n" - "by numpy.diagonal. This code will likely break in the next numpy\n" - "release -- see numpy.diagonal docs for details. The quick fix is\n" - "to make an explicit copy (e.g., do arr.diagonal().copy())."; + "by numpy.diagonal or by selecting multiple fields in a record\n" + "array. This code will likely break in the next numpy release --\n" + "see numpy.diagonal or arrays.indexing reference docs for details.\n" + "The quick fix is to make an explicit copy (e.g., do\n" + "arr.diagonal().copy() or arr[['f0','f1']].copy())."; if (PyArray_FLAGS(obj) & NPY_ARRAY_WARN_ON_WRITE) { if (DEPRECATE_FUTUREWARNING(msg) < 0) { return -1; diff --git a/numpy/core/src/multiarray/mapping.c b/numpy/core/src/multiarray/mapping.c index cdefb9982..663a3ef7f 100644 --- a/numpy/core/src/multiarray/mapping.c +++ b/numpy/core/src/multiarray/mapping.c @@ -976,6 +976,10 @@ array_subscript(PyArrayObject *self, PyObject *op) obj = PyObject_CallMethod(_numpy_internal, "_index_fields", "OO", self, op); Py_DECREF(_numpy_internal); + if (obj == NULL) { + return NULL; + } + PyArray_ENABLEFLAGS((PyArrayObject*)obj, NPY_ARRAY_WARN_ON_WRITE); return obj; } } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 3427800a2..e3e24fae1 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -1909,7 +1909,8 @@ class TestRecord(TestCase): def test_field_names(self): # Test unicode and 8-bit / byte strings can be used a = np.zeros((1,), dtype=[('f1', 'i4'), - ('f2', [('sf1', 'i4')])]) + ('f2', 'i4'), + ('f3', [('sf1', 'i4')])]) is_py3 = sys.version_info[0] >= 3 if is_py3: funcs = (str,) @@ -1934,12 +1935,18 @@ class TestRecord(TestCase): assert_raises(IndexError, b[0].__setitem__, fnn, 1) assert_raises(IndexError, b[0].__getitem__, fnn) # Subfield - fn2 = func('f2') + fn3 = func('f3') sfn1 = func('sf1') - b[fn2][sfn1] = 1 - assert_equal(b[fn2][sfn1], 1) - assert_raises(ValueError, b[fn2].__setitem__, fnn, 1) - assert_raises(ValueError, b[fn2].__getitem__, fnn) + b[fn3][sfn1] = 1 + assert_equal(b[fn3][sfn1], 1) + assert_raises(ValueError, b[fn3].__setitem__, fnn, 1) + assert_raises(ValueError, b[fn3].__getitem__, fnn) + # multiple Subfields + fn2 = func('f2') + b[fn2] = 3 + assert_equal(b[['f1','f2']][0].tolist(), (2, 3)) + assert_equal(b[['f2','f1']][0].tolist(), (3, 2)) + assert_equal(b[['f1','f3']][0].tolist(), (2, (1,))) # non-ascii unicode field indexing is well behaved if not is_py3: raise SkipTest('non ascii unicode field indexing skipped; ' @@ -1948,6 +1955,47 @@ class TestRecord(TestCase): assert_raises(ValueError, a.__setitem__, u'\u03e0', 1) assert_raises(ValueError, a.__getitem__, u'\u03e0') + def test_field_names_deprecation(self): + import warnings + from numpy.testing.utils import WarningManager + def collect_warning_types(f, *args, **kwargs): + ctx = WarningManager(record=True) + warning_log = ctx.__enter__() + warnings.simplefilter("always") + try: + f(*args, **kwargs) + finally: + ctx.__exit__() + return [w.category for w in warning_log] + a = np.zeros((1,), dtype=[('f1', 'i4'), + ('f2', 'i4'), + ('f3', [('sf1', 'i4')])]) + a['f1'][0] = 1 + a['f2'][0] = 2 + a['f3'][0] = (3,) + b = np.zeros((1,), dtype=[('f1', 'i4'), + ('f2', 'i4'), + ('f3', [('sf1', 'i4')])]) + b['f1'][0] = 1 + b['f2'][0] = 2 + b['f3'][0] = (3,) + + # All the different functions raise a warning, but not an error, and + # 'a' is not modified: + assert_equal(collect_warning_types(a[['f1','f2']].__setitem__, 0, (10,20)), + [FutureWarning]) + assert_equal(a, b) + # Views also warn + subset = a[['f1','f2']] + subset_view = subset.view() + assert_equal(collect_warning_types(subset_view['f1'].__setitem__, 0, 10), + [FutureWarning]) + # But the write goes through: + assert_equal(subset['f1'][0], 10) + # Only one warning per multiple field indexing, though (even if there are + # multiple views involved): + assert_equal(collect_warning_types(subset['f1'].__setitem__, 0, 10), + []) class TestView(TestCase): def test_basic(self): |