summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2010-11-09 16:09:14 -0800
committerCharles Harris <charlesr.harris@gmail.com>2010-12-01 20:02:15 -0700
commit88feef8f35cfb30795ed5c02031b69d99827b6f4 (patch)
tree79a7f68dd35603998976112f4bbc437aa08049e3
parent8f354f6208ef14753f2c5988a11536d5918c2c38 (diff)
downloadnumpy-88feef8f35cfb30795ed5c02031b69d99827b6f4.tar.gz
ENH: core: Fix up coercion rules for half/float16
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c1
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src30
-rw-r--r--numpy/core/src/umath/ufunc_object.c7
-rw-r--r--numpy/core/tests/test_half.py27
4 files changed, 46 insertions, 19 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index d2aaf054f..0040e8ad5 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -567,7 +567,6 @@ PyArray_CanCastSafely(int fromtype, int totype)
}
}
}
-
return 0;
}
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 0ac374fd5..c66337602 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -3297,10 +3297,10 @@ initialize_casting_tables(void)
/* Compile-time loop of scalar kinds */
/**begin repeat
* #NAME = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
- * LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
+ * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE#
- * #SCKIND = BOOL, (INTNEG, INTPOS)*5, FLOAT, FLOAT, FLOAT,
- * COMPLEX, COMPLEX, COMPLEX#
+ * #SCKIND = BOOL, (INTNEG, INTPOS)*5, FLOAT*4,
+ * COMPLEX*3#
*/
_npy_scalar_kinds_table[PyArray_@NAME@] = PyArray_@SCKIND@_SCALAR;
/**end repeat**/
@@ -3330,22 +3330,22 @@ initialize_casting_tables(void)
/* Compile-time loop of casting rules */
/**begin repeat
* #FROM_NAME = BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
- * LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
+ * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE#
* #FROM_BASENAME = BYTE, BYTE, SHORT, SHORT, INT, INT, LONG, LONG,
- * LONGLONG, LONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
+ * LONGLONG, LONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE,
* FLOAT, DOUBLE, LONGDOUBLE#
* #from_isint = 1, 0, 1, 0, 1, 0, 1, 0,
- * 1, 0, 0, 0, 0,
+ * 1, 0, 0, 0, 0, 0,
* 0, 0, 0#
* #from_isuint = 0, 1, 0, 1, 0, 1, 0, 1,
- * 0, 1, 0, 0, 0,
+ * 0, 1, 0, 0, 0, 0,
* 0, 0, 0#
* #from_isfloat = 0, 0, 0, 0, 0, 0, 0, 0,
- * 0, 0, 1, 1, 1,
+ * 0, 0, 1, 1, 1, 1,
* 0, 0, 0#
* #from_iscomplex = 0, 0, 0, 0, 0, 0, 0, 0,
- * 0, 0, 0, 0, 0,
+ * 0, 0, 0, 0, 0, 0,
* 1, 1, 1#
*/
#define _FROM_BSIZE NPY_SIZEOF_@FROM_BASENAME@
@@ -3356,22 +3356,22 @@ initialize_casting_tables(void)
/**begin repeat1
* #TO_NAME = BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
- * LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
+ * LONGLONG, ULONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE#
* #TO_BASENAME = BYTE, BYTE, SHORT, SHORT, INT, INT, LONG, LONG,
- * LONGLONG, LONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
+ * LONGLONG, LONGLONG, HALF, FLOAT, DOUBLE, LONGDOUBLE,
* FLOAT, DOUBLE, LONGDOUBLE#
* #to_isint = 1, 0, 1, 0, 1, 0, 1, 0,
- * 1, 0, 0, 0, 0,
+ * 1, 0, 0, 0, 0, 0,
* 0, 0, 0#
* #to_isuint = 0, 1, 0, 1, 0, 1, 0, 1,
- * 0, 1, 0, 0, 0,
+ * 0, 1, 0, 0, 0, 0,
* 0, 0, 0#
* #to_isfloat = 0, 0, 0, 0, 0, 0, 0, 0,
- * 0, 0, 1, 1, 1,
+ * 0, 0, 1, 1, 1, 1,
* 0, 0, 0#
* #to_iscomplex = 0, 0, 0, 0, 0, 0, 0, 0,
- * 0, 0, 0, 0, 0,
+ * 0, 0, 0, 0, 0, 0,
* 1, 1, 1#
*/
#define _TO_BSIZE NPY_SIZEOF_@TO_BASENAME@
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index be38a9f10..39b04db73 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -224,10 +224,11 @@ _lowest_type(char intype)
case PyArray_ULONG:
case PyArray_ULONGLONG:
return PyArray_UBYTE;
- /* case PyArray_FLOAT:*/
+ /* case PyArray_HALF: */
+ case PyArray_FLOAT:
case PyArray_DOUBLE:
case PyArray_LONGDOUBLE:
- return PyArray_FLOAT;
+ return PyArray_HALF;
/* case PyArray_CFLOAT:*/
case PyArray_CDOUBLE:
case PyArray_CLONGDOUBLE:
@@ -3328,7 +3329,7 @@ PyUFunc_GenericReduction(PyUFuncObject *self, PyObject *args,
* is used for add and multiply reduction to avoid overflow
*/
int typenum = PyArray_TYPE(mp);
- if ((typenum < NPY_FLOAT)
+ if ((typenum < NPY_HALF)
&& ((strcmp(self->name,"add") == 0)
|| (strcmp(self->name,"multiply") == 0))) {
if (PyTypeNum_ISBOOL(typenum)) {
diff --git a/numpy/core/tests/test_half.py b/numpy/core/tests/test_half.py
index 4d61bc28f..5a547d35e 100644
--- a/numpy/core/tests/test_half.py
+++ b/numpy/core/tests/test_half.py
@@ -276,3 +276,30 @@ def test_half_ufuncs():
assert_equal(np.frexp(b), ([-0.5,0.625,0.5,0.5,0.75],[2,3,1,3,2]))
assert_equal(np.ldexp(b,[0,1,2,4,2]), [-2,10,4,64,12])
+def test_half_coercion():
+ """Test that half gets coerced properly with the other types"""
+ a16 = np.array((1,),dtype=float16)
+ a32 = np.array((1,),dtype=float32)
+ b16 = float16(1)
+ b32 = float32(1)
+
+ assert_equal(np.power(a16,2).dtype, float16)
+ assert_equal(np.power(a16,2.0).dtype, float16)
+ assert_equal(np.power(a16,b16).dtype, float16)
+ assert_equal(np.power(a16,b32).dtype, float16)
+ assert_equal(np.power(a16,a16).dtype, float16)
+ assert_equal(np.power(a16,a32).dtype, float32)
+
+ assert_equal(np.power(b16,2).dtype, float64)
+ assert_equal(np.power(b16,2.0).dtype, float64)
+ assert_equal(np.power(b16,b16).dtype, float16)
+ assert_equal(np.power(b16,b32).dtype, float32)
+ assert_equal(np.power(b16,a16).dtype, float16)
+ assert_equal(np.power(b16,a32).dtype, float32)
+
+ assert_equal(np.power(a32,a16).dtype, float32)
+ assert_equal(np.power(a32,b16).dtype, float32)
+ assert_equal(np.power(b32,a16).dtype, float16)
+ assert_equal(np.power(b32,b16).dtype, float32)
+
+