summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2009-02-02 05:20:17 +0000
committerpierregm <pierregm@localhost>2009-02-02 05:20:17 +0000
commitf278427bb52fdfdc524d1da4032777ca5290e49e (patch)
treebe3542eb760eee87294e083af574d9600b5f9264 /numpy
parentd3e84d6b104ee2d95e46ffd65d461f5351755a46 (diff)
downloadnumpy-f278427bb52fdfdc524d1da4032777ca5290e49e.tar.gz
* Added a 'autoconvert' option to stack_arrays.
* Fixed 'stack_arrays' to work with fields with titles.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/recfunctions.py21
-rw-r--r--numpy/lib/tests/test_recfunctions.py32
2 files changed, 47 insertions, 6 deletions
diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py
index 4b781c621..b3eecdc0e 100644
--- a/numpy/lib/recfunctions.py
+++ b/numpy/lib/recfunctions.py
@@ -628,7 +628,8 @@ def rec_append_fields(base, names, data, dtypes=None):
-def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False):
+def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
+ autoconvert=False):
"""
Superposes arrays fields by fields
@@ -644,6 +645,8 @@ def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False):
asrecarray : {False, True}, optional
Whether to return a recarray (or MaskedRecords if `usemask==True`) or
just a flexible-type ndarray.
+ autoconvert : {False, True}, optional
+ Whether automatically cast the type of the field to the maximum.
Examples
--------
@@ -673,16 +676,24 @@ def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False):
#
dtype_l = ndtype[0]
newdescr = dtype_l.descr
- names = list(dtype_l.names or ()) or ['']
+ names = [_[0] for _ in newdescr]
for dtype_n in ndtype[1:]:
for descr in dtype_n.descr:
name = descr[0] or ''
if name not in names:
newdescr.append(descr)
names.append(name)
- elif descr[1] != dict(newdescr)[name]:
- raise TypeError("Incompatible type '%s' <> '%s'" %\
- (dict(newdescr)[name], descr[1]))
+ else:
+ nameidx = names.index(name)
+ current_descr = newdescr[nameidx]
+ if autoconvert:
+ if np.dtype(descr[1]) > np.dtype(current_descr[-1]):
+ current_descr = list(current_descr)
+ current_descr[-1] = descr[1]
+ newdescr[nameidx] = tuple(current_descr)
+ elif descr[1] != current_descr[-1]:
+ raise TypeError("Incompatible type '%s' <> '%s'" %\
+ (dict(newdescr)[name], descr[1]))
# Only one field: use concatenate
if len(newdescr) == 1:
output = ma.concatenate(seqarrays)
diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py
index 1a405090d..f7a72071c 100644
--- a/numpy/lib/tests/test_recfunctions.py
+++ b/numpy/lib/tests/test_recfunctions.py
@@ -485,7 +485,6 @@ class TestStackArrays(TestCase):
assert_equal(test.mask, control.mask)
- #
def test_defaults(self):
"Test defaults: no exception raised if keys of defaults are not fields."
(_, _, _, z) = self.data
@@ -503,6 +502,37 @@ class TestStackArrays(TestCase):
assert_equal(test.mask, control.mask)
+ def test_autoconversion(self):
+ "Tests autoconversion"
+ adtype = [('A', int), ('B', bool), ('C', float)]
+ a = ma.array([(1, 2, 3)], mask=[(0, 1, 0)], dtype=adtype)
+ bdtype = [('A', int), ('B', float), ('C', float)]
+ b = ma.array([(4, 5, 6)], dtype=bdtype)
+ control = ma.array([(1, 2, 3), (4, 5, 6)], mask=[(0, 1, 0), (0, 0, 0)],
+ dtype=bdtype)
+ test = stack_arrays((a, b), autoconvert=True)
+ assert_equal(test, control)
+ assert_equal(test.mask, control.mask)
+ try:
+ test = stack_arrays((a, b), autoconvert=False)
+ except TypeError:
+ pass
+ else:
+ raise AssertionError
+
+
+ def test_checktitles(self):
+ "Test using titles in the field names"
+ adtype = [(('a', 'A'), int), (('b', 'B'), bool), (('c', 'C'), float)]
+ a = ma.array([(1, 2, 3)], mask=[(0, 1, 0)], dtype=adtype)
+ bdtype = [(('a', 'A'), int), (('b', 'B'), bool), (('c', 'C'), float)]
+ b = ma.array([(4, 5, 6)], dtype=bdtype)
+ test = stack_arrays((a, b))
+ control = ma.array([(1, 2, 3), (4, 5, 6)], mask=[(0, 1, 0), (0, 0, 0)],
+ dtype=bdtype)
+ assert_equal(test, control)
+ assert_equal(test.mask, control.mask)
+
class TestJoinBy(TestCase):
#