diff options
author | Erik M. Bray <embray@stsci.edu> | 2015-05-27 11:06:17 -0400 |
---|---|---|
committer | Erik M. Bray <embray@stsci.edu> | 2015-06-01 12:51:15 -0400 |
commit | 776c3c32b759a62542939fa2abd729f9ad77ae9b (patch) | |
tree | afedfb26df99413c16f7135ad0431f5fe2fd3966 | |
parent | 9e7a0b224b4a7859efc1bc80b7ccc5c86c83d3df (diff) | |
download | numpy-776c3c32b759a62542939fa2abd729f9ad77ae9b.tar.gz |
BUG: Further fixes to record and recarray getitem/getattr
This is a followup to PR #5505, which didn't go quite far
enough. This fixes two issues in particular:
1) The record class also needs an updated `__getitem__` that works
analogously to its `__getattribute__` so that a nested record is
returned as a `record` object and not a plain `np.void`. In other
words the old behavior is:
```python
>>> rec = np.rec.array([('abc ', (1,1), 1), ('abc', (2,3), 1)],
... dtype=[('foo', 'S4'), ('bar', [('A', int), ('B', int)]),
('baz', int)])
>>> rec[0].bar
(1, 1)
>>> type(rec[0].bar)
<class 'numpy.record'>
>>> type(rec[0]['bar'])
<type 'numpy.void'>
```
demonstrated inconsistency between `.bar` and `['bar']` on the record
object. The new behavior is:
```python
>>> type(rec[0]['bar'])
<class 'numpy.record'>
```
2) The second issue is more subtle. The fix to #5505 used the
`obj.dtype.descr` attribute to create a new dtype of type `record`.
However, this does not recreate the correct type if the fields are
not aligned. To demonstrate:
```python
>>> dt = np.dtype({'C': ('S5', 0), 'D': ('S5', 6)})
>>> dt.fields
dict_proxy({'C': (dtype('S5'), 0), 'D': (dtype('S5'), 6)})
>>> dt.descr
[('C', '|S5'), ('', '|V1'), ('D', '|S5')]
>>> new_dt = np.dtype((np.record, dt.descr))
>>> new_dt
dtype((numpy.record, [('C', 'S5'), ('f1', 'V1'), ('D', 'S5')]))
>>> new_dt.fields
dict_proxy({'f1': (dtype('V1'), 5), 'C': (dtype('S5'), 0), 'D':
(dtype('S5'), 6)})
```
Using the `fields` dict to construct the new type reconstructs the
correct type with the correct offsets:
```python
>>> new_dt2 = np.dtype((np.record, dt.fields))
>>> new_dt2.fields
dict_proxy({'C': (dtype('S5'), 0), 'D': (dtype('S5'), 6)})
```
(Note: This is based on #5920 for convenience, but I could decouple
the changes if that's preferable.)
-rw-r--r-- | numpy/core/records.py | 17 | ||||
-rw-r--r-- | numpy/core/tests/test_records.py | 17 |
2 files changed, 28 insertions, 6 deletions
diff --git a/numpy/core/records.py b/numpy/core/records.py index 243645436..b1ea176e4 100644 --- a/numpy/core/records.py +++ b/numpy/core/records.py @@ -245,13 +245,12 @@ class record(nt.void): #happens if field is Object type return obj if dt.fields: - return obj.view((record, obj.dtype.descr)) + return obj.view((self.__class__, obj.dtype.fields)) return obj else: raise AttributeError("'record' object has no " "attribute '%s'" % attr) - def __setattr__(self, attr, val): if attr in ['setfield', 'getfield', 'dtype']: raise AttributeError("Cannot set '%s' attribute" % attr) @@ -266,6 +265,16 @@ class record(nt.void): raise AttributeError("'record' object has no " "attribute '%s'" % attr) + def __getitem__(self, indx): + obj = nt.void.__getitem__(self, indx) + + # copy behavior of record.__getattribute__, + if isinstance(obj, nt.void) and obj.dtype.fields: + return obj.view((self.__class__, obj.dtype.fields)) + else: + # return a single element + return obj + def pprint(self): """Pretty-print all fields.""" # pretty-print all fields @@ -438,7 +447,7 @@ class recarray(ndarray): # to preserve numpy.record type if present), since nested structured # fields do not inherit type. if obj.dtype.fields: - return obj.view(dtype=(self.dtype.type, obj.dtype.descr)) + return obj.view(dtype=(self.dtype.type, obj.dtype.fields)) else: return obj.view(ndarray) @@ -478,7 +487,7 @@ class recarray(ndarray): # we might also be returning a single element if isinstance(obj, ndarray): if obj.dtype.fields: - return obj.view(dtype=(self.dtype.type, obj.dtype.descr)) + return obj.view(dtype=(self.dtype.type, obj.dtype.fields)) else: return obj.view(type=ndarray) else: diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py index a7895a30a..2f86b7ccb 100644 --- a/numpy/core/tests/test_records.py +++ b/numpy/core/tests/test_records.py @@ -149,19 +149,32 @@ class TestFromrecords(TestCase): assert_equal(a.foo[0] == a.foo[1], False) def test_recarray_returntypes(self): - a = np.rec.array([('abc ', (1,1), 1), ('abc', (2,3), 1)], + qux_fields = {'C': (np.dtype('S5'), 0), 'D': (np.dtype('S5'), 6)} + a = np.rec.array([('abc ', (1,1), 1, ('abcde', 'fgehi')), + ('abc', (2,3), 1, ('abcde', 'jklmn'))], dtype=[('foo', 'S4'), ('bar', [('A', int), ('B', int)]), - ('baz', int)]) + ('baz', int), ('qux', qux_fields)]) assert_equal(type(a.foo), np.ndarray) assert_equal(type(a['foo']), np.ndarray) assert_equal(type(a.bar), np.recarray) assert_equal(type(a['bar']), np.recarray) assert_equal(a.bar.dtype.type, np.record) + assert_equal(type(a['qux']), np.recarray) + assert_equal(a.qux.dtype.type, np.record) + assert_equal(dict(a.qux.dtype.fields), qux_fields) assert_equal(type(a.baz), np.ndarray) assert_equal(type(a['baz']), np.ndarray) assert_equal(type(a[0].bar), np.record) + assert_equal(type(a[0]['bar']), np.record) assert_equal(a[0].bar.A, 1) + assert_equal(a[0].bar['A'], 1) + assert_equal(a[0]['bar'].A, 1) + assert_equal(a[0]['bar']['A'], 1) + assert_equal(a[0].qux.D, asbytes('fgehi')) + assert_equal(a[0].qux['D'], asbytes('fgehi')) + assert_equal(a[0]['qux'].D, asbytes('fgehi')) + assert_equal(a[0]['qux']['D'], asbytes('fgehi')) class TestRecord(TestCase): |