diff options
-rw-r--r-- | numpy/lib/recfunctions.py | 21 | ||||
-rw-r--r-- | numpy/lib/tests/test_recfunctions.py | 32 |
2 files changed, 47 insertions, 6 deletions
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py index 4b781c621..b3eecdc0e 100644 --- a/numpy/lib/recfunctions.py +++ b/numpy/lib/recfunctions.py @@ -628,7 +628,8 @@ def rec_append_fields(base, names, data, dtypes=None): -def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False): +def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False, + autoconvert=False): """ Superposes arrays fields by fields @@ -644,6 +645,8 @@ def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False): asrecarray : {False, True}, optional Whether to return a recarray (or MaskedRecords if `usemask==True`) or just a flexible-type ndarray. + autoconvert : {False, True}, optional + Whether automatically cast the type of the field to the maximum. Examples -------- @@ -673,16 +676,24 @@ def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False): # dtype_l = ndtype[0] newdescr = dtype_l.descr - names = list(dtype_l.names or ()) or [''] + names = [_[0] for _ in newdescr] for dtype_n in ndtype[1:]: for descr in dtype_n.descr: name = descr[0] or '' if name not in names: newdescr.append(descr) names.append(name) - elif descr[1] != dict(newdescr)[name]: - raise TypeError("Incompatible type '%s' <> '%s'" %\ - (dict(newdescr)[name], descr[1])) + else: + nameidx = names.index(name) + current_descr = newdescr[nameidx] + if autoconvert: + if np.dtype(descr[1]) > np.dtype(current_descr[-1]): + current_descr = list(current_descr) + current_descr[-1] = descr[1] + newdescr[nameidx] = tuple(current_descr) + elif descr[1] != current_descr[-1]: + raise TypeError("Incompatible type '%s' <> '%s'" %\ + (dict(newdescr)[name], descr[1])) # Only one field: use concatenate if len(newdescr) == 1: output = ma.concatenate(seqarrays) diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py index 1a405090d..f7a72071c 100644 --- a/numpy/lib/tests/test_recfunctions.py +++ b/numpy/lib/tests/test_recfunctions.py @@ -485,7 +485,6 @@ class TestStackArrays(TestCase): assert_equal(test.mask, control.mask) - # def test_defaults(self): "Test defaults: no exception raised if keys of defaults are not fields." (_, _, _, z) = self.data @@ -503,6 +502,37 @@ class TestStackArrays(TestCase): assert_equal(test.mask, control.mask) + def test_autoconversion(self): + "Tests autoconversion" + adtype = [('A', int), ('B', bool), ('C', float)] + a = ma.array([(1, 2, 3)], mask=[(0, 1, 0)], dtype=adtype) + bdtype = [('A', int), ('B', float), ('C', float)] + b = ma.array([(4, 5, 6)], dtype=bdtype) + control = ma.array([(1, 2, 3), (4, 5, 6)], mask=[(0, 1, 0), (0, 0, 0)], + dtype=bdtype) + test = stack_arrays((a, b), autoconvert=True) + assert_equal(test, control) + assert_equal(test.mask, control.mask) + try: + test = stack_arrays((a, b), autoconvert=False) + except TypeError: + pass + else: + raise AssertionError + + + def test_checktitles(self): + "Test using titles in the field names" + adtype = [(('a', 'A'), int), (('b', 'B'), bool), (('c', 'C'), float)] + a = ma.array([(1, 2, 3)], mask=[(0, 1, 0)], dtype=adtype) + bdtype = [(('a', 'A'), int), (('b', 'B'), bool), (('c', 'C'), float)] + b = ma.array([(4, 5, 6)], dtype=bdtype) + test = stack_arrays((a, b)) + control = ma.array([(1, 2, 3), (4, 5, 6)], mask=[(0, 1, 0), (0, 0, 0)], + dtype=bdtype) + assert_equal(test, control) + assert_equal(test.mask, control.mask) + class TestJoinBy(TestCase): # |