summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2015-06-22 12:13:03 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2015-06-22 12:13:57 -0400
commit157e396673c4992e97a522dd9f350b480c4cb6c6 (patch)
tree3339b7421095daaf91e3f1585f26d7468fd0e302 /numpy/lib
parenta43e86b0d9d567c7abb9478d5bff90905d3f70ec (diff)
downloadnumpy-157e396673c4992e97a522dd9f350b480c4cb6c6.tar.gz
BUG: np.float16 not recognized in np.common_type
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/tests/test_type_check.py6
-rw-r--r--numpy/lib/type_check.py19
2 files changed, 14 insertions, 11 deletions
diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py
index 7afd1206c..f7430c27d 100644
--- a/numpy/lib/tests/test_type_check.py
+++ b/numpy/lib/tests/test_type_check.py
@@ -18,11 +18,13 @@ def assert_all(x):
class TestCommonType(TestCase):
def test_basic(self):
ai32 = np.array([[1, 2], [3, 4]], dtype=np.int32)
+ af16 = np.array([[1, 2], [3, 4]], dtype=np.float16)
af32 = np.array([[1, 2], [3, 4]], dtype=np.float32)
af64 = np.array([[1, 2], [3, 4]], dtype=np.float64)
acs = np.array([[1+5j, 2+6j], [3+7j, 4+8j]], dtype=np.csingle)
acd = np.array([[1+5j, 2+6j], [3+7j, 4+8j]], dtype=np.cdouble)
assert_(common_type(ai32) == np.float64)
+ assert_(common_type(af16) == np.float16)
assert_(common_type(af32) == np.float32)
assert_(common_type(af64) == np.float64)
assert_(common_type(acs) == np.csingle)
@@ -186,7 +188,7 @@ class TestIsnan(TestCase):
class TestIsfinite(TestCase):
- # Fixme, wrong place, isfinite now ufunc
+ # Fixme, wrong place, isfinite now ufunc
def test_goodvalues(self):
z = np.array((-1., 0., 1.))
@@ -217,7 +219,7 @@ class TestIsfinite(TestCase):
class TestIsinf(TestCase):
- # Fixme, wrong place, isinf now ufunc
+ # Fixme, wrong place, isinf now ufunc
def test_goodvalues(self):
z = np.array((-1., 0., 1.))
diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py
index 99677b394..2fe4e7d23 100644
--- a/numpy/lib/type_check.py
+++ b/numpy/lib/type_check.py
@@ -532,14 +532,15 @@ def typename(char):
#-----------------------------------------------------------------------------
#determine the "minimum common type" for a group of arrays.
-array_type = [[_nx.single, _nx.double, _nx.longdouble],
- [_nx.csingle, _nx.cdouble, _nx.clongdouble]]
-array_precision = {_nx.single: 0,
- _nx.double: 1,
- _nx.longdouble: 2,
- _nx.csingle: 0,
- _nx.cdouble: 1,
- _nx.clongdouble: 2}
+array_type = [[_nx.half, _nx.single, _nx.double, _nx.longdouble],
+ [None, _nx.csingle, _nx.cdouble, _nx.clongdouble]]
+array_precision = {_nx.half: 0,
+ _nx.single: 1,
+ _nx.double: 2,
+ _nx.longdouble: 3,
+ _nx.csingle: 1,
+ _nx.cdouble: 2,
+ _nx.clongdouble: 3}
def common_type(*arrays):
"""
Return a scalar type which is common to the input arrays.
@@ -583,7 +584,7 @@ def common_type(*arrays):
if iscomplexobj(a):
is_complex = True
if issubclass(t, _nx.integer):
- p = 1
+ p = 2 # array_precision[_nx.double]
else:
p = array_precision.get(t, None)
if p is None: