diff options
author | Travis Oliphant <oliphant@enthought.com> | 2009-07-16 04:29:27 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2009-07-16 04:29:27 +0000 |
commit | c80006f71a025c1428ea381c26a3f16e27ec14b0 (patch) | |
tree | 51014440ac11c923610ea29772a647071453feb0 | |
parent | ec034e7fc2a02c7dd8892b08884d69aca0c99194 (diff) | |
download | numpy-c80006f71a025c1428ea381c26a3f16e27ec14b0.tar.gz |
Fix #728 again. This time don't use max on a partially-ordered set.
-rw-r--r-- | numpy/core/numerictypes.py | 30 | ||||
-rw-r--r-- | numpy/core/tests/test_numerictypes.py | 4 |
2 files changed, 21 insertions, 13 deletions
diff --git a/numpy/core/numerictypes.py b/numpy/core/numerictypes.py index e773c77ae..d329ed70e 100644 --- a/numpy/core/numerictypes.py +++ b/numpy/core/numerictypes.py @@ -656,14 +656,24 @@ def _find_common_coerce(a, b): thisind = __test_types.index(a.char) except ValueError: return None + return _can_coerce_all([a,b], start=thisind) + +# Find a data-type that all data-types in a list can be coerced to +def _can_coerce_all(dtypelist, start=0): + N = len(dtypelist) + if N == 0: + return None + if N == 1: + return dtypelist[0] + thisind = start while thisind < __len_test_types: newdtype = dtype(__test_types[thisind]) - if newdtype >= b and newdtype >= a: + numcoerce = len([x for x in dtypelist if newdtype >= x]) + if numcoerce == N: return newdtype thisind += 1 return None - def find_common_type(array_types, scalar_types): """ Determine common type following standard coercion rules @@ -692,16 +702,14 @@ def find_common_type(array_types, scalar_types): array_types = [dtype(x) for x in array_types] scalar_types = [dtype(x) for x in scalar_types] - if len(scalar_types) == 0: - if len(array_types) == 0: - return None - else: - return max(array_types) - if len(array_types) == 0: - return max(scalar_types) + maxa = _can_coerce_all(array_types) + maxsc = _can_coerce_all(scalar_types) + + if maxa is None: + return maxsc - maxa = max(array_types) - maxsc = max(scalar_types) + if maxsc is None: + return maxa try: index_a = _kind_list.index(maxa.kind) diff --git a/numpy/core/tests/test_numerictypes.py b/numpy/core/tests/test_numerictypes.py index 4e0bb462b..56ed4dbb1 100644 --- a/numpy/core/tests/test_numerictypes.py +++ b/numpy/core/tests/test_numerictypes.py @@ -338,13 +338,13 @@ class TestEmptyField(TestCase): class TestCommonType(TestCase): def test_scalar_loses1(self): - res = np.find_common_type(['f4','f4','i4'],['f8']) + res = np.find_common_type(['f4','f4','i2'],['f8']) assert(res == 'f4') def test_scalar_loses2(self): res = np.find_common_type(['f4','f4'],['i8']) assert(res == 'f4') def test_scalar_wins(self): - res = np.find_common_type(['f4','f4','i4'],['c8']) + res = np.find_common_type(['f4','f4','i2'],['c8']) assert(res == 'c8') def test_scalar_wins2(self): res = np.find_common_type(['u4','i4','i4'],['f4']) |