diff options
author | Stefan van der Walt <stefan@sun.ac.za> | 2007-08-21 06:58:53 +0000 |
---|---|---|
committer | Stefan van der Walt <stefan@sun.ac.za> | 2007-08-21 06:58:53 +0000 |
commit | 859439e5ebcdab8f5a5ced4dd42be537e4b037e5 (patch) | |
tree | 00a578dba801bdc87f6b3ed896b29691eb2a9380 /numpy | |
parent | 47221fe49a26140a57c6af6569afbaf055db15c8 (diff) | |
download | numpy-859439e5ebcdab8f5a5ced4dd42be537e4b037e5.tar.gz |
Fix record assignment (based on a patch by Sameer DCosta).
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/records.py | 11 | ||||
-rw-r--r-- | numpy/core/tests/test_records.py | 25 |
2 files changed, 30 insertions, 6 deletions
diff --git a/numpy/core/records.py b/numpy/core/records.py index 38d6410d1..ed5b55408 100644 --- a/numpy/core/records.py +++ b/numpy/core/records.py @@ -152,17 +152,16 @@ class record(nt.void): def __setattr__(self, attr, val): if attr in ['setfield', 'getfield', 'dtype']: raise AttributeError, "Cannot set '%s' attribute" % attr - try: - return nt.void.__setattr__(self, attr, val) - except AttributeError: - pass fielddict = nt.void.__getattribute__(self, 'dtype').fields res = fielddict.get(attr, None) if res: return self.setfield(val, *res[:2]) else: - raise AttributeError, "'record' object has no "\ - "attribute '%s'" % attr + if getattr(self,attr,None): + return nt.void.__setattr__(self, attr, val) + else: + raise AttributeError, "'record' object has no "\ + "attribute '%s'" % attr def pprint(self): # pretty-print all fields diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py index c1eca7b58..4fc263ada 100644 --- a/numpy/core/tests/test_records.py +++ b/numpy/core/tests/test_records.py @@ -85,5 +85,30 @@ class test_fromrecords(NumpyTestCase): assert_array_equal(ra['field'], [[5,5,5]]) assert callable(ra.field) +class test_record(NumpyTestCase): + def setUp(self): + self.data = rec.fromrecords([(1,2,3),(4,5,6)], + dtype=[("col1", "<i4"), + ("col2", "<i4"), + ("col3", "<i4")]) + + def test_assignment1(self): + a = self.data + assert_equal(a.col1[0],1) + a[0].col1 = 0 + assert_equal(a.col1[0],0) + + def test_assignment2(self): + a = self.data + assert_equal(a.col1[0],1) + a.col1[0] = 0 + assert_equal(a.col1[0],0) + + def test_invalid_assignment(self): + a = self.data + def assign_invalid_column(x): + x[0].col5 = 1 + self.failUnlessRaises(AttributeError,assign_invalid_column,a) + if __name__ == "__main__": NumpyTest().run() |