diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2015-06-22 12:13:03 -0400 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2015-06-22 12:13:57 -0400 |
commit | 157e396673c4992e97a522dd9f350b480c4cb6c6 (patch) | |
tree | 3339b7421095daaf91e3f1585f26d7468fd0e302 /numpy/lib | |
parent | a43e86b0d9d567c7abb9478d5bff90905d3f70ec (diff) | |
download | numpy-157e396673c4992e97a522dd9f350b480c4cb6c6.tar.gz |
BUG: np.float16 not recognized in np.common_type
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/tests/test_type_check.py | 6 | ||||
-rw-r--r-- | numpy/lib/type_check.py | 19 |
2 files changed, 14 insertions, 11 deletions
diff --git a/numpy/lib/tests/test_type_check.py b/numpy/lib/tests/test_type_check.py index 7afd1206c..f7430c27d 100644 --- a/numpy/lib/tests/test_type_check.py +++ b/numpy/lib/tests/test_type_check.py @@ -18,11 +18,13 @@ def assert_all(x): class TestCommonType(TestCase): def test_basic(self): ai32 = np.array([[1, 2], [3, 4]], dtype=np.int32) + af16 = np.array([[1, 2], [3, 4]], dtype=np.float16) af32 = np.array([[1, 2], [3, 4]], dtype=np.float32) af64 = np.array([[1, 2], [3, 4]], dtype=np.float64) acs = np.array([[1+5j, 2+6j], [3+7j, 4+8j]], dtype=np.csingle) acd = np.array([[1+5j, 2+6j], [3+7j, 4+8j]], dtype=np.cdouble) assert_(common_type(ai32) == np.float64) + assert_(common_type(af16) == np.float16) assert_(common_type(af32) == np.float32) assert_(common_type(af64) == np.float64) assert_(common_type(acs) == np.csingle) @@ -186,7 +188,7 @@ class TestIsnan(TestCase): class TestIsfinite(TestCase): - # Fixme, wrong place, isfinite now ufunc + # Fixme, wrong place, isfinite now ufunc def test_goodvalues(self): z = np.array((-1., 0., 1.)) @@ -217,7 +219,7 @@ class TestIsfinite(TestCase): class TestIsinf(TestCase): - # Fixme, wrong place, isinf now ufunc + # Fixme, wrong place, isinf now ufunc def test_goodvalues(self): z = np.array((-1., 0., 1.)) diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py index 99677b394..2fe4e7d23 100644 --- a/numpy/lib/type_check.py +++ b/numpy/lib/type_check.py @@ -532,14 +532,15 @@ def typename(char): #----------------------------------------------------------------------------- #determine the "minimum common type" for a group of arrays. -array_type = [[_nx.single, _nx.double, _nx.longdouble], - [_nx.csingle, _nx.cdouble, _nx.clongdouble]] -array_precision = {_nx.single: 0, - _nx.double: 1, - _nx.longdouble: 2, - _nx.csingle: 0, - _nx.cdouble: 1, - _nx.clongdouble: 2} +array_type = [[_nx.half, _nx.single, _nx.double, _nx.longdouble], + [None, _nx.csingle, _nx.cdouble, _nx.clongdouble]] +array_precision = {_nx.half: 0, + _nx.single: 1, + _nx.double: 2, + _nx.longdouble: 3, + _nx.csingle: 1, + _nx.cdouble: 2, + _nx.clongdouble: 3} def common_type(*arrays): """ Return a scalar type which is common to the input arrays. @@ -583,7 +584,7 @@ def common_type(*arrays): if iscomplexobj(a): is_complex = True if issubclass(t, _nx.integer): - p = 1 + p = 2 # array_precision[_nx.double] else: p = array_precision.get(t, None) if p is None: |