summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2009-07-16 04:29:27 +0000
committerTravis Oliphant <oliphant@enthought.com>2009-07-16 04:29:27 +0000
commitc80006f71a025c1428ea381c26a3f16e27ec14b0 (patch)
tree51014440ac11c923610ea29772a647071453feb0
parentec034e7fc2a02c7dd8892b08884d69aca0c99194 (diff)
downloadnumpy-c80006f71a025c1428ea381c26a3f16e27ec14b0.tar.gz
Fix #728 again. This time don't use max on a partially-ordered set.
-rw-r--r--numpy/core/numerictypes.py30
-rw-r--r--numpy/core/tests/test_numerictypes.py4
2 files changed, 21 insertions, 13 deletions
diff --git a/numpy/core/numerictypes.py b/numpy/core/numerictypes.py
index e773c77ae..d329ed70e 100644
--- a/numpy/core/numerictypes.py
+++ b/numpy/core/numerictypes.py
@@ -656,14 +656,24 @@ def _find_common_coerce(a, b):
thisind = __test_types.index(a.char)
except ValueError:
return None
+ return _can_coerce_all([a,b], start=thisind)
+
+# Find a data-type that all data-types in a list can be coerced to
+def _can_coerce_all(dtypelist, start=0):
+ N = len(dtypelist)
+ if N == 0:
+ return None
+ if N == 1:
+ return dtypelist[0]
+ thisind = start
while thisind < __len_test_types:
newdtype = dtype(__test_types[thisind])
- if newdtype >= b and newdtype >= a:
+ numcoerce = len([x for x in dtypelist if newdtype >= x])
+ if numcoerce == N:
return newdtype
thisind += 1
return None
-
def find_common_type(array_types, scalar_types):
"""
Determine common type following standard coercion rules
@@ -692,16 +702,14 @@ def find_common_type(array_types, scalar_types):
array_types = [dtype(x) for x in array_types]
scalar_types = [dtype(x) for x in scalar_types]
- if len(scalar_types) == 0:
- if len(array_types) == 0:
- return None
- else:
- return max(array_types)
- if len(array_types) == 0:
- return max(scalar_types)
+ maxa = _can_coerce_all(array_types)
+ maxsc = _can_coerce_all(scalar_types)
+
+ if maxa is None:
+ return maxsc
- maxa = max(array_types)
- maxsc = max(scalar_types)
+ if maxsc is None:
+ return maxa
try:
index_a = _kind_list.index(maxa.kind)
diff --git a/numpy/core/tests/test_numerictypes.py b/numpy/core/tests/test_numerictypes.py
index 4e0bb462b..56ed4dbb1 100644
--- a/numpy/core/tests/test_numerictypes.py
+++ b/numpy/core/tests/test_numerictypes.py
@@ -338,13 +338,13 @@ class TestEmptyField(TestCase):
class TestCommonType(TestCase):
def test_scalar_loses1(self):
- res = np.find_common_type(['f4','f4','i4'],['f8'])
+ res = np.find_common_type(['f4','f4','i2'],['f8'])
assert(res == 'f4')
def test_scalar_loses2(self):
res = np.find_common_type(['f4','f4'],['i8'])
assert(res == 'f4')
def test_scalar_wins(self):
- res = np.find_common_type(['f4','f4','i4'],['c8'])
+ res = np.find_common_type(['f4','f4','i2'],['c8'])
assert(res == 'c8')
def test_scalar_wins2(self):
res = np.find_common_type(['u4','i4','i4'],['f4'])