summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorErik M. Bray <embray@stsci.edu>2015-05-27 11:06:17 -0400
committerErik M. Bray <embray@stsci.edu>2015-06-01 12:51:15 -0400
commit776c3c32b759a62542939fa2abd729f9ad77ae9b (patch)
treeafedfb26df99413c16f7135ad0431f5fe2fd3966
parent9e7a0b224b4a7859efc1bc80b7ccc5c86c83d3df (diff)
downloadnumpy-776c3c32b759a62542939fa2abd729f9ad77ae9b.tar.gz
BUG: Further fixes to record and recarray getitem/getattr
This is a followup to PR #5505, which didn't go quite far enough. This fixes two issues in particular: 1) The record class also needs an updated `__getitem__` that works analogously to its `__getattribute__` so that a nested record is returned as a `record` object and not a plain `np.void`. In other words the old behavior is: ```python >>> rec = np.rec.array([('abc ', (1,1), 1), ('abc', (2,3), 1)], ... dtype=[('foo', 'S4'), ('bar', [('A', int), ('B', int)]), ('baz', int)]) >>> rec[0].bar (1, 1) >>> type(rec[0].bar) <class 'numpy.record'> >>> type(rec[0]['bar']) <type 'numpy.void'> ``` demonstrated inconsistency between `.bar` and `['bar']` on the record object. The new behavior is: ```python >>> type(rec[0]['bar']) <class 'numpy.record'> ``` 2) The second issue is more subtle. The fix to #5505 used the `obj.dtype.descr` attribute to create a new dtype of type `record`. However, this does not recreate the correct type if the fields are not aligned. To demonstrate: ```python >>> dt = np.dtype({'C': ('S5', 0), 'D': ('S5', 6)}) >>> dt.fields dict_proxy({'C': (dtype('S5'), 0), 'D': (dtype('S5'), 6)}) >>> dt.descr [('C', '|S5'), ('', '|V1'), ('D', '|S5')] >>> new_dt = np.dtype((np.record, dt.descr)) >>> new_dt dtype((numpy.record, [('C', 'S5'), ('f1', 'V1'), ('D', 'S5')])) >>> new_dt.fields dict_proxy({'f1': (dtype('V1'), 5), 'C': (dtype('S5'), 0), 'D': (dtype('S5'), 6)}) ``` Using the `fields` dict to construct the new type reconstructs the correct type with the correct offsets: ```python >>> new_dt2 = np.dtype((np.record, dt.fields)) >>> new_dt2.fields dict_proxy({'C': (dtype('S5'), 0), 'D': (dtype('S5'), 6)}) ``` (Note: This is based on #5920 for convenience, but I could decouple the changes if that's preferable.)
-rw-r--r--numpy/core/records.py17
-rw-r--r--numpy/core/tests/test_records.py17
2 files changed, 28 insertions, 6 deletions
diff --git a/numpy/core/records.py b/numpy/core/records.py
index 243645436..b1ea176e4 100644
--- a/numpy/core/records.py
+++ b/numpy/core/records.py
@@ -245,13 +245,12 @@ class record(nt.void):
#happens if field is Object type
return obj
if dt.fields:
- return obj.view((record, obj.dtype.descr))
+ return obj.view((self.__class__, obj.dtype.fields))
return obj
else:
raise AttributeError("'record' object has no "
"attribute '%s'" % attr)
-
def __setattr__(self, attr, val):
if attr in ['setfield', 'getfield', 'dtype']:
raise AttributeError("Cannot set '%s' attribute" % attr)
@@ -266,6 +265,16 @@ class record(nt.void):
raise AttributeError("'record' object has no "
"attribute '%s'" % attr)
+ def __getitem__(self, indx):
+ obj = nt.void.__getitem__(self, indx)
+
+ # copy behavior of record.__getattribute__,
+ if isinstance(obj, nt.void) and obj.dtype.fields:
+ return obj.view((self.__class__, obj.dtype.fields))
+ else:
+ # return a single element
+ return obj
+
def pprint(self):
"""Pretty-print all fields."""
# pretty-print all fields
@@ -438,7 +447,7 @@ class recarray(ndarray):
# to preserve numpy.record type if present), since nested structured
# fields do not inherit type.
if obj.dtype.fields:
- return obj.view(dtype=(self.dtype.type, obj.dtype.descr))
+ return obj.view(dtype=(self.dtype.type, obj.dtype.fields))
else:
return obj.view(ndarray)
@@ -478,7 +487,7 @@ class recarray(ndarray):
# we might also be returning a single element
if isinstance(obj, ndarray):
if obj.dtype.fields:
- return obj.view(dtype=(self.dtype.type, obj.dtype.descr))
+ return obj.view(dtype=(self.dtype.type, obj.dtype.fields))
else:
return obj.view(type=ndarray)
else:
diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py
index a7895a30a..2f86b7ccb 100644
--- a/numpy/core/tests/test_records.py
+++ b/numpy/core/tests/test_records.py
@@ -149,19 +149,32 @@ class TestFromrecords(TestCase):
assert_equal(a.foo[0] == a.foo[1], False)
def test_recarray_returntypes(self):
- a = np.rec.array([('abc ', (1,1), 1), ('abc', (2,3), 1)],
+ qux_fields = {'C': (np.dtype('S5'), 0), 'D': (np.dtype('S5'), 6)}
+ a = np.rec.array([('abc ', (1,1), 1, ('abcde', 'fgehi')),
+ ('abc', (2,3), 1, ('abcde', 'jklmn'))],
dtype=[('foo', 'S4'),
('bar', [('A', int), ('B', int)]),
- ('baz', int)])
+ ('baz', int), ('qux', qux_fields)])
assert_equal(type(a.foo), np.ndarray)
assert_equal(type(a['foo']), np.ndarray)
assert_equal(type(a.bar), np.recarray)
assert_equal(type(a['bar']), np.recarray)
assert_equal(a.bar.dtype.type, np.record)
+ assert_equal(type(a['qux']), np.recarray)
+ assert_equal(a.qux.dtype.type, np.record)
+ assert_equal(dict(a.qux.dtype.fields), qux_fields)
assert_equal(type(a.baz), np.ndarray)
assert_equal(type(a['baz']), np.ndarray)
assert_equal(type(a[0].bar), np.record)
+ assert_equal(type(a[0]['bar']), np.record)
assert_equal(a[0].bar.A, 1)
+ assert_equal(a[0].bar['A'], 1)
+ assert_equal(a[0]['bar'].A, 1)
+ assert_equal(a[0]['bar']['A'], 1)
+ assert_equal(a[0].qux.D, asbytes('fgehi'))
+ assert_equal(a[0].qux['D'], asbytes('fgehi'))
+ assert_equal(a[0]['qux'].D, asbytes('fgehi'))
+ assert_equal(a[0]['qux']['D'], asbytes('fgehi'))
class TestRecord(TestCase):