summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-06-27 02:52:49 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-06-27 02:52:49 +0000
commit7ae3b470b424cd8f80fdf54eab22a7fa8ac127ac (patch)
treea0aeee09bedf27940e674d2445adf46186606a30
parent41471a6e9fcf1a6fc18630b1f7d0716ba977211d (diff)
downloadnumpy-7ae3b470b424cd8f80fdf54eab22a7fa8ac127ac.tar.gz
Add unit-test for record-arrays with object field.
-rw-r--r--numpy/core/records.py10
-rw-r--r--numpy/core/tests/test_records.py15
2 files changed, 23 insertions, 2 deletions
diff --git a/numpy/core/records.py b/numpy/core/records.py
index 333bc486e..aebf3b3c2 100644
--- a/numpy/core/records.py
+++ b/numpy/core/records.py
@@ -176,7 +176,7 @@ class recarray(sb.ndarray):
fielddict = sb.ndarray.__getattribute__(self,'dtype').fields
try:
res = fielddict[attr][:2]
- except KeyError:
+ except (TypeError, KeyError):
raise AttributeError, "record array has no attribute %s" % attr
obj = self.getfield(*res)
# if it has fields return a recarray, otherwise return
@@ -195,10 +195,16 @@ class recarray(sb.ndarray):
fielddict = sb.ndarray.__getattribute__(self,'dtype').fields
try:
res = fielddict[attr][:2]
- except KeyError:
+ except (TypeError,KeyError):
raise AttributeError, "record array has no attribute %s" % attr
return self.setfield(val,*res)
+ def __getitem__(self, indx):
+ obj = sb.ndarray.__getitem__(self, indx)
+ if (isinstance(obj, sb.ndarray) and obj.dtype.isbuiltin):
+ return obj.view(sb.ndarray)
+ return obj
+
def field(self,attr, val=None):
fielddict = sb.ndarray.__getattribute__(self,'dtype').fields
diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py
index 4c0e8831b..b431201b5 100644
--- a/numpy/core/tests/test_records.py
+++ b/numpy/core/tests/test_records.py
@@ -39,5 +39,20 @@ class test_fromrecords(NumpyTestCase):
fd.seek(2880*2)
r = rec.fromfile(fd, formats='f8,i4,a5', shape=3, byteorder='big')
+ def check_recarray_from_obj(self):
+ count = 10
+ a = zeros(count, dtype='O')
+ b = zeros(count, dtype='f8')
+ c = zeros(count, dtype='f8')
+ for i in range(len(a)):
+ a[i] = range(1,10)
+
+ mine = numpy.rec.fromarrays([a,b,c],
+ names='date,data1,data2')
+ for i in range(len(a)):
+ assert(mine.date[i]==range(1,10))
+ assert(mine.data1[i]==0.0)
+ assert(mine.data2[i]==0.0)
+
if __name__ == "__main__":
NumpyTest().run()