diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2010-11-09 16:09:14 -0800 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2010-12-01 20:02:15 -0700 |
commit | 88feef8f35cfb30795ed5c02031b69d99827b6f4 (patch) | |
tree | 79a7f68dd35603998976112f4bbc437aa08049e3 | |
parent | 8f354f6208ef14753f2c5988a11536d5918c2c38 (diff) | |
download | numpy-88feef8f35cfb30795ed5c02031b69d99827b6f4.tar.gz |
ENH: core: Fix up coercion rules for half/float16
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 1 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 30 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 7 | ||||
-rw-r--r-- | numpy/core/tests/test_half.py | 27 |
4 files changed, 46 insertions, 19 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index d2aaf054f..0040e8ad5 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -567,7 +567,6 @@ PyArray_CanCastSafely(int fromtype, int totype) } } } - return 0; } diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index 0ac374fd5..c66337602 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -3297,10 +3297,10 @@ initialize_casting_tables(void) /* Compile-time loop of scalar kinds */ /**begin repeat * #NAME = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, - * LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE, + * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, * CFLOAT, CDOUBLE, CLONGDOUBLE# - * #SCKIND = BOOL, (INTNEG, INTPOS)*5, FLOAT, FLOAT, FLOAT, - * COMPLEX, COMPLEX, COMPLEX# + * #SCKIND = BOOL, (INTNEG, INTPOS)*5, FLOAT*4, + * COMPLEX*3# */ _npy_scalar_kinds_table[PyArray_@NAME@] = PyArray_@SCKIND@_SCALAR; /**end repeat**/ @@ -3330,22 +3330,22 @@ initialize_casting_tables(void) /* Compile-time loop of casting rules */ /**begin repeat * #FROM_NAME = BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, - * LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE, + * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, * CFLOAT, CDOUBLE, CLONGDOUBLE# * #FROM_BASENAME = BYTE, BYTE, SHORT, SHORT, INT, INT, LONG, LONG, - * LONGLONG, LONGLONG, FLOAT, DOUBLE, LONGDOUBLE, + * LONGLONG, LONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, * FLOAT, DOUBLE, LONGDOUBLE# * #from_isint = 1, 0, 1, 0, 1, 0, 1, 0, - * 1, 0, 0, 0, 0, + * 1, 0, 0, 0, 0, 0, * 0, 0, 0# * #from_isuint = 0, 1, 0, 1, 0, 1, 0, 1, - * 0, 1, 0, 0, 0, + * 0, 1, 0, 0, 0, 0, * 0, 0, 0# * #from_isfloat = 0, 0, 0, 0, 0, 0, 0, 0, - * 0, 0, 1, 1, 1, + * 0, 0, 1, 1, 1, 1, * 0, 0, 0# * #from_iscomplex = 0, 0, 0, 0, 0, 0, 0, 0, - * 0, 0, 0, 0, 0, + * 0, 0, 0, 0, 0, 0, * 1, 1, 1# */ #define _FROM_BSIZE NPY_SIZEOF_@FROM_BASENAME@ @@ -3356,22 +3356,22 @@ initialize_casting_tables(void) /**begin repeat1 * #TO_NAME = BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, - * LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE, + * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, * CFLOAT, CDOUBLE, CLONGDOUBLE# * #TO_BASENAME = BYTE, BYTE, SHORT, SHORT, INT, INT, LONG, LONG, - * LONGLONG, LONGLONG, FLOAT, DOUBLE, LONGDOUBLE, + * LONGLONG, LONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE, * FLOAT, DOUBLE, LONGDOUBLE# * #to_isint = 1, 0, 1, 0, 1, 0, 1, 0, - * 1, 0, 0, 0, 0, + * 1, 0, 0, 0, 0, 0, * 0, 0, 0# * #to_isuint = 0, 1, 0, 1, 0, 1, 0, 1, - * 0, 1, 0, 0, 0, + * 0, 1, 0, 0, 0, 0, * 0, 0, 0# * #to_isfloat = 0, 0, 0, 0, 0, 0, 0, 0, - * 0, 0, 1, 1, 1, + * 0, 0, 1, 1, 1, 1, * 0, 0, 0# * #to_iscomplex = 0, 0, 0, 0, 0, 0, 0, 0, - * 0, 0, 0, 0, 0, + * 0, 0, 0, 0, 0, 0, * 1, 1, 1# */ #define _TO_BSIZE NPY_SIZEOF_@TO_BASENAME@ diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index be38a9f10..39b04db73 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -224,10 +224,11 @@ _lowest_type(char intype) case PyArray_ULONG: case PyArray_ULONGLONG: return PyArray_UBYTE; - /* case PyArray_FLOAT:*/ + /* case PyArray_HALF: */ + case PyArray_FLOAT: case PyArray_DOUBLE: case PyArray_LONGDOUBLE: - return PyArray_FLOAT; + return PyArray_HALF; /* case PyArray_CFLOAT:*/ case PyArray_CDOUBLE: case PyArray_CLONGDOUBLE: @@ -3328,7 +3329,7 @@ PyUFunc_GenericReduction(PyUFuncObject *self, PyObject *args, * is used for add and multiply reduction to avoid overflow */ int typenum = PyArray_TYPE(mp); - if ((typenum < NPY_FLOAT) + if ((typenum < NPY_HALF) && ((strcmp(self->name,"add") == 0) || (strcmp(self->name,"multiply") == 0))) { if (PyTypeNum_ISBOOL(typenum)) { diff --git a/numpy/core/tests/test_half.py b/numpy/core/tests/test_half.py index 4d61bc28f..5a547d35e 100644 --- a/numpy/core/tests/test_half.py +++ b/numpy/core/tests/test_half.py @@ -276,3 +276,30 @@ def test_half_ufuncs(): assert_equal(np.frexp(b), ([-0.5,0.625,0.5,0.5,0.75],[2,3,1,3,2])) assert_equal(np.ldexp(b,[0,1,2,4,2]), [-2,10,4,64,12]) +def test_half_coercion(): + """Test that half gets coerced properly with the other types""" + a16 = np.array((1,),dtype=float16) + a32 = np.array((1,),dtype=float32) + b16 = float16(1) + b32 = float32(1) + + assert_equal(np.power(a16,2).dtype, float16) + assert_equal(np.power(a16,2.0).dtype, float16) + assert_equal(np.power(a16,b16).dtype, float16) + assert_equal(np.power(a16,b32).dtype, float16) + assert_equal(np.power(a16,a16).dtype, float16) + assert_equal(np.power(a16,a32).dtype, float32) + + assert_equal(np.power(b16,2).dtype, float64) + assert_equal(np.power(b16,2.0).dtype, float64) + assert_equal(np.power(b16,b16).dtype, float16) + assert_equal(np.power(b16,b32).dtype, float32) + assert_equal(np.power(b16,a16).dtype, float16) + assert_equal(np.power(b16,a32).dtype, float32) + + assert_equal(np.power(a32,a16).dtype, float32) + assert_equal(np.power(a32,b16).dtype, float32) + assert_equal(np.power(b32,a16).dtype, float16) + assert_equal(np.power(b32,b16).dtype, float32) + + |