summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJay Bourque <jay.bourque@continuum.io>2012-08-31 12:51:44 -0500
committerJay Bourque <jay.bourque@continuum.io>2012-08-31 12:51:44 -0500
commit93100c92b57dd9663b688fdd94efc7d05ef7ff38 (patch)
treee7db39b2237b84ce593f4f5593588af074089ad1
parent68320a10f2e29a70a9a39110263c040aab689147 (diff)
downloadnumpy-93100c92b57dd9663b688fdd94efc7d05ef7ff38.tar.gz
Fix returned copy
Fix returned copy so that copy of view with offsets copies only fields in view, not all the fields from original array. Also add unit tests to make sure this doesn't break when copy is fully deprecated in favor of returning a view.
-rw-r--r--numpy/core/_internal.py6
-rw-r--r--numpy/core/tests/test_multiarray.py5
2 files changed, 10 insertions, 1 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 92ab0c8b0..fbe580dee 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -295,7 +295,11 @@ def _index_fields(ary, fields):
view_dtype = {'names':names, 'formats':formats, 'offsets':offsets, 'itemsize':dt.itemsize}
view = ary.view(dtype=view_dtype)
- return view.copy()
+ # Return a copy for now until behavior is fully deprecated
+ # in favor of returning view
+ copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']}
+ from numpy import array
+ return array(view, dtype=copy_dtype, copy=True)
# Given a string containing a PEP 3118 format specifier,
# construct a Numpy dtype
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index b9fd3ad86..118f221ae 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -1956,6 +1956,11 @@ class TestRecord(TestCase):
assert_equal(b[['f1','f2']][0].tolist(), (2, 3))
assert_equal(b[['f2','f1']][0].tolist(), (3, 2))
assert_equal(b[['f1','f3']][0].tolist(), (2, (1,)))
+ # view of subfield view/copy
+ assert_equal(b[['f1','f2']][0].view(('i4',2)).tolist(), (2, 3))
+ assert_equal(b[['f2','f1']][0].view(('i4',2)).tolist(), (3, 2))
+ view_dtype=[('f1', 'i4'),('f3', [('', 'i4')])]
+ assert_equal(b[['f1','f3']][0].view(view_dtype).tolist(), (2, (1,)))
# non-ascii unicode field indexing is well behaved
if not is_py3:
raise SkipTest('non ascii unicode field indexing skipped; '