diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-06-27 02:52:49 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-06-27 02:52:49 +0000 |
commit | 7ae3b470b424cd8f80fdf54eab22a7fa8ac127ac (patch) | |
tree | a0aeee09bedf27940e674d2445adf46186606a30 | |
parent | 41471a6e9fcf1a6fc18630b1f7d0716ba977211d (diff) | |
download | numpy-7ae3b470b424cd8f80fdf54eab22a7fa8ac127ac.tar.gz |
Add unit-test for record-arrays with object field.
-rw-r--r-- | numpy/core/records.py | 10 | ||||
-rw-r--r-- | numpy/core/tests/test_records.py | 15 |
2 files changed, 23 insertions, 2 deletions
diff --git a/numpy/core/records.py b/numpy/core/records.py index 333bc486e..aebf3b3c2 100644 --- a/numpy/core/records.py +++ b/numpy/core/records.py @@ -176,7 +176,7 @@ class recarray(sb.ndarray): fielddict = sb.ndarray.__getattribute__(self,'dtype').fields try: res = fielddict[attr][:2] - except KeyError: + except (TypeError, KeyError): raise AttributeError, "record array has no attribute %s" % attr obj = self.getfield(*res) # if it has fields return a recarray, otherwise return @@ -195,10 +195,16 @@ class recarray(sb.ndarray): fielddict = sb.ndarray.__getattribute__(self,'dtype').fields try: res = fielddict[attr][:2] - except KeyError: + except (TypeError,KeyError): raise AttributeError, "record array has no attribute %s" % attr return self.setfield(val,*res) + def __getitem__(self, indx): + obj = sb.ndarray.__getitem__(self, indx) + if (isinstance(obj, sb.ndarray) and obj.dtype.isbuiltin): + return obj.view(sb.ndarray) + return obj + def field(self,attr, val=None): fielddict = sb.ndarray.__getattribute__(self,'dtype').fields diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py index 4c0e8831b..b431201b5 100644 --- a/numpy/core/tests/test_records.py +++ b/numpy/core/tests/test_records.py @@ -39,5 +39,20 @@ class test_fromrecords(NumpyTestCase): fd.seek(2880*2) r = rec.fromfile(fd, formats='f8,i4,a5', shape=3, byteorder='big') + def check_recarray_from_obj(self): + count = 10 + a = zeros(count, dtype='O') + b = zeros(count, dtype='f8') + c = zeros(count, dtype='f8') + for i in range(len(a)): + a[i] = range(1,10) + + mine = numpy.rec.fromarrays([a,b,c], + names='date,data1,data2') + for i in range(len(a)): + assert(mine.date[i]==range(1,10)) + assert(mine.data1[i]==0.0) + assert(mine.data2[i]==0.0) + if __name__ == "__main__": NumpyTest().run() |