summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2010-11-28 17:00:36 -0800
committerMark Wiebe <mwwiebe@gmail.com>2010-11-30 09:10:21 -0800
commitaf84876fac13ac2e4e44ac0cae599fe9d6e68643 (patch)
tree516a1dbdb03df1eb4fde2a7e1dd166a974cdb1a3
parent9273a6139e7b797244b5b88fb371059cc7c3ea3a (diff)
downloadnumpy-af84876fac13ac2e4e44ac0cae599fe9d6e68643.tar.gz
ENH: Remove type number ordering assumptions in CanCastSafely, ScalarKinds, and CanCoerceScalar
Also add print_coercion_tables.py to aid when refactoring type casting/coercion/promotion.
-rw-r--r--numpy/core/src/multiarray/convert_datatype.c150
-rw-r--r--numpy/core/src/multiarray/getset.c22
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c109
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src187
-rw-r--r--numpy/core/src/multiarray/scalartypes.h16
-rw-r--r--numpy/core/src/multiarray/usertypes.c13
-rw-r--r--numpy/core/tests/test_numeric.py43
-rwxr-xr-xnumpy/testing/print_coercion_tables.py79
8 files changed, 444 insertions, 175 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index f93a45717..d2aaf054f 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -522,145 +522,53 @@ PyArray_CastAnyTo(PyArrayObject *out, PyArrayObject *mp)
NPY_NO_EXPORT int
PyArray_CanCastSafely(int fromtype, int totype)
{
- PyArray_Descr *from, *to;
- int felsize, telsize;
+ PyArray_Descr *from;
- if (fromtype == totype) {
- return 1;
+ /* Fast table lookup for small type numbers */
+ if ((unsigned int)fromtype < NPY_NTYPES && (unsigned int)totype < NPY_NTYPES) {
+ return _npy_can_cast_safely_table[fromtype][totype];
}
- if (fromtype == PyArray_BOOL) {
+
+ /* Identity */
+ if (fromtype == totype) {
return 1;
}
- if (totype == PyArray_BOOL) {
- return 0;
- }
- if (fromtype == PyArray_DATETIME || fromtype == PyArray_TIMEDELTA ||
- totype == PyArray_DATETIME || totype == PyArray_TIMEDELTA) {
- return 0;
- }
- if (totype == PyArray_OBJECT || totype == PyArray_VOID) {
- return 1;
+ /* Special-cases for some types */
+ switch (fromtype) {
+ case PyArray_DATETIME:
+ case PyArray_TIMEDELTA:
+ case PyArray_OBJECT:
+ case PyArray_VOID:
+ return 0;
+ case PyArray_BOOL:
+ return 1;
}
- if (fromtype == PyArray_OBJECT || fromtype == PyArray_VOID) {
- return 0;
+ switch (totype) {
+ case PyArray_BOOL:
+ case PyArray_DATETIME:
+ case PyArray_TIMEDELTA:
+ return 0;
+ case PyArray_OBJECT:
+ case PyArray_VOID:
+ return 1;
}
+
from = PyArray_DescrFromType(fromtype);
/*
* cancastto is a PyArray_NOTYPE terminated C-int-array of types that
* the data-type can be cast to safely.
*/
if (from->f->cancastto) {
- int *curtype;
- curtype = from->f->cancastto;
+ int *curtype = from->f->cancastto;
+
while (*curtype != PyArray_NOTYPE) {
if (*curtype++ == totype) {
return 1;
}
}
}
- if (PyTypeNum_ISUSERDEF(totype)) {
- return 0;
- }
- to = PyArray_DescrFromType(totype);
- telsize = to->elsize;
- felsize = from->elsize;
- Py_DECREF(from);
- Py_DECREF(to);
-
- switch(fromtype) {
- case PyArray_BYTE:
- case PyArray_SHORT:
- case PyArray_INT:
- case PyArray_LONG:
- case PyArray_LONGLONG:
- if (PyTypeNum_ISINTEGER(totype)) {
- if (PyTypeNum_ISUNSIGNED(totype)) {
- return 0;
- }
- else {
- return telsize >= felsize;
- }
- }
- else if (PyTypeNum_ISFLOAT(totype)) {
- if (felsize < 8) {
- return telsize > felsize;
- }
- else {
- return telsize >= felsize;
- }
- }
- else if (PyTypeNum_ISCOMPLEX(totype)) {
- if (felsize < 8) {
- return (telsize >> 1) > felsize;
- }
- else {
- return (telsize >> 1) >= felsize;
- }
- }
- else {
- return totype > fromtype;
- }
- case PyArray_UBYTE:
- case PyArray_USHORT:
- case PyArray_UINT:
- case PyArray_ULONG:
- case PyArray_ULONGLONG:
- if (PyTypeNum_ISINTEGER(totype)) {
- if (PyTypeNum_ISSIGNED(totype)) {
- return telsize > felsize;
- }
- else {
- return telsize >= felsize;
- }
- }
- else if (PyTypeNum_ISFLOAT(totype)) {
- if (felsize < 8) {
- return telsize > felsize;
- }
- else {
- return telsize >= felsize;
- }
- }
- else if (PyTypeNum_ISCOMPLEX(totype)) {
- if (felsize < 8) {
- return (telsize >> 1) > felsize;
- }
- else {
- return (telsize >> 1) >= felsize;
- }
- }
- else {
- return totype > fromtype;
- }
- case PyArray_FLOAT:
- case PyArray_DOUBLE:
- case PyArray_LONGDOUBLE:
- if (PyTypeNum_ISCOMPLEX(totype)) {
- return (telsize >> 1) >= felsize;
- }
- else if (PyTypeNum_ISFLOAT(totype) && (telsize == felsize)) {
- /* On some systems, double == longdouble */
- return 1;
- }
- else {
- return totype > fromtype;
- }
- case PyArray_CFLOAT:
- case PyArray_CDOUBLE:
- case PyArray_CLONGDOUBLE:
- if (PyTypeNum_ISCOMPLEX(totype) && (telsize == felsize)) {
- /* On some systems, double == longdouble */
- return 1;
- }
- else {
- return totype > fromtype;
- }
- case PyArray_STRING:
- case PyArray_UNICODE:
- return totype > fromtype;
- default:
- return 0;
- }
+
+ return 0;
}
/*NUMPY_API
diff --git a/numpy/core/src/multiarray/getset.c b/numpy/core/src/multiarray/getset.c
index b35058238..a636383a6 100644
--- a/numpy/core/src/multiarray/getset.c
+++ b/numpy/core/src/multiarray/getset.c
@@ -576,12 +576,30 @@ array_base_get(PyArrayObject *self)
static PyArrayObject *
_get_part(PyArrayObject *self, int imag)
{
+ int float_type_num;
PyArray_Descr *type;
PyArrayObject *ret;
int offset;
- type = PyArray_DescrFromType(self->descr->type_num -
- PyArray_NUM_FLOATTYPE);
+ switch (self->descr->type_num) {
+ case PyArray_CFLOAT:
+ float_type_num = PyArray_FLOAT;
+ break;
+ case PyArray_CDOUBLE:
+ float_type_num = PyArray_DOUBLE;
+ break;
+ case PyArray_CLONGDOUBLE:
+ float_type_num = PyArray_LONGDOUBLE;
+ break;
+ default:
+ PyErr_Format(PyExc_ValueError,
+ "Cannot convert complex type number %d to float",
+ self->descr->type_num);
+ return NULL;
+
+ }
+ type = PyArray_DescrFromType(float_type_num);
+
offset = (imag ? type->elsize : 0);
if (!PyArray_ISNBO(self->descr->byteorder)) {
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 60642ccd3..f35bfd662 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -480,48 +480,44 @@ _signbit_set(PyArrayObject *arr)
/*NUMPY_API
* ScalarKind
+ *
+ * Returns the scalar kind of a type number, with an
+ * optional tweak based on the scalar value itself.
+ * If no scalar is provided, it returns INTPOS_SCALAR
+ * for both signed and unsigned integers, otherwise
+ * it checks the sign of any signed integer to choose
+ * INTNEG_SCALAR when appropriate.
*/
NPY_NO_EXPORT NPY_SCALARKIND
PyArray_ScalarKind(int typenum, PyArrayObject **arr)
{
- if (PyTypeNum_ISSIGNED(typenum)) {
- if (arr && _signbit_set(*arr)) {
- return PyArray_INTNEG_SCALAR;
- }
- else {
- return PyArray_INTPOS_SCALAR;
- }
- }
- if (PyTypeNum_ISFLOAT(typenum)) {
- return PyArray_FLOAT_SCALAR;
- }
- if (PyTypeNum_ISUNSIGNED(typenum)) {
- return PyArray_INTPOS_SCALAR;
- }
- if (PyTypeNum_ISCOMPLEX(typenum)) {
- return PyArray_COMPLEX_SCALAR;
- }
- if (PyTypeNum_ISBOOL(typenum)) {
- return PyArray_BOOL_SCALAR;
- }
+ NPY_SCALARKIND ret = PyArray_NOSCALAR;
- if (PyTypeNum_ISUSERDEF(typenum)) {
- NPY_SCALARKIND retval;
+ if ((unsigned int)typenum < NPY_NTYPES) {
+ ret = _npy_scalar_kinds_table[typenum];
+ /* Signed integer types are INTNEG in the table */
+ if (ret == PyArray_INTNEG_SCALAR) {
+ if (!arr || !_signbit_set(*arr)) {
+ ret = PyArray_INTPOS_SCALAR;
+ }
+ }
+ } else if (PyTypeNum_ISUSERDEF(typenum)) {
PyArray_Descr* descr = PyArray_DescrFromType(typenum);
if (descr->f->scalarkind) {
- retval = descr->f->scalarkind((arr ? *arr : NULL));
- }
- else {
- retval = PyArray_NOSCALAR;
+ ret = descr->f->scalarkind((arr ? *arr : NULL));
}
Py_DECREF(descr);
- return retval;
}
- return PyArray_OBJECT_SCALAR;
+
+ return ret;
}
-/*NUMPY_API*/
+/*NUMPY_API
+ *
+ * Determines whether the data type 'thistype', with
+ * scalar kind 'scalar', can be coerced into 'neededtype'.
+ */
NPY_NO_EXPORT int
PyArray_CanCoerceScalar(int thistype, int neededtype,
NPY_SCALARKIND scalar)
@@ -529,9 +525,39 @@ PyArray_CanCoerceScalar(int thistype, int neededtype,
PyArray_Descr* from;
int *castlist;
+ /* If 'thistype' is not a scalar, it must be safely castable */
if (scalar == PyArray_NOSCALAR) {
return PyArray_CanCastSafely(thistype, neededtype);
}
+ if ((unsigned int)neededtype < NPY_NTYPES) {
+ NPY_SCALARKIND neededscalar;
+
+ if (scalar == PyArray_OBJECT_SCALAR) {
+ return PyArray_CanCastSafely(thistype, neededtype);
+ }
+
+ /*
+ * The lookup table gives us exactly what we need for
+ * this comparison, which PyArray_ScalarKind would not.
+ *
+ * The rule is that positive scalars can be coerced
+ * to a signed ints, but negative scalars cannot be coerced
+ * to unsigned ints.
+ * _npy_scalar_kinds_table[int]==NEGINT > POSINT,
+ * so 1 is returned, but
+ * _npy_scalar_kinds_table[uint]==POSINT < NEGINT,
+ * so 0 is returned, as required.
+ *
+ */
+ neededscalar = _npy_scalar_kinds_table[neededtype];
+ if (neededscalar >= scalar) {
+ return 1;
+ }
+ if (!PyTypeNum_ISUSERDEF(thistype)) {
+ return 0;
+ }
+ }
+
from = PyArray_DescrFromType(thistype);
if (from->f->cancastscalarkindto
&& (castlist = from->f->cancastscalarkindto[scalar])) {
@@ -544,29 +570,7 @@ PyArray_CanCoerceScalar(int thistype, int neededtype,
}
Py_DECREF(from);
- switch(scalar) {
- case PyArray_BOOL_SCALAR:
- case PyArray_OBJECT_SCALAR:
- return PyArray_CanCastSafely(thistype, neededtype);
- default:
- if (PyTypeNum_ISUSERDEF(neededtype)) {
- return FALSE;
- }
- switch(scalar) {
- case PyArray_INTPOS_SCALAR:
- return (neededtype >= PyArray_BYTE);
- case PyArray_INTNEG_SCALAR:
- return (neededtype >= PyArray_BYTE)
- && !(PyTypeNum_ISUNSIGNED(neededtype));
- case PyArray_FLOAT_SCALAR:
- return (neededtype >= PyArray_FLOAT);
- case PyArray_COMPLEX_SCALAR:
- return (neededtype >= PyArray_CFLOAT);
- default:
- /* should never get here... */
- return 1;
- }
- }
+ return 0;
}
/*
@@ -2821,6 +2825,7 @@ static struct PyMethodDef array_module_methods[] = {
static int
setup_scalartypes(PyObject *NPY_UNUSED(dict))
{
+ initialize_casting_tables();
initialize_numeric_types();
if (PyType_Ready(&PyBool_Type) < 0) {
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index 3cc3e3d19..b2e8690c7 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -3248,6 +3248,193 @@ NPY_NO_EXPORT PyTypeObject Py@NAME@ArrType_Type = {
/**end repeat**/
+#ifdef NPY_ENABLE_SEPARATE_COMPILATION
+/*
+ * This table maps the built-in type numbers to their scalar
+ * type numbers. Note that signed integers are mapped to INTNEG_SCALAR,
+ * which is different than what PyArray_ScalarKind returns.
+ */
+NPY_NO_EXPORT char
+_npy_scalar_kinds_table[NPY_NTYPES];
+/*
+ * This table describes safe casting for small type numbers,
+ * and is used by PyArray_CanCastSafely.
+ */
+NPY_NO_EXPORT unsigned char
+_npy_can_cast_safely_table[NPY_NTYPES][NPY_NTYPES];
+#endif
+
+NPY_NO_EXPORT void
+initialize_casting_tables(void)
+{
+ int i;
+
+ /* Default for built-in types is object scalar */
+ memset(_npy_scalar_kinds_table, PyArray_OBJECT_SCALAR,
+ sizeof(_npy_scalar_kinds_table));
+
+ /* Compile-time loop of scalar kinds */
+/**begin repeat
+ * #NAME = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
+ * LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE#
+ * #SCKIND = BOOL, (INTNEG, INTPOS)*5, FLOAT, FLOAT, FLOAT,
+ * COMPLEX, COMPLEX, COMPLEX#
+ */
+ _npy_scalar_kinds_table[PyArray_@NAME@] = PyArray_@SCKIND@_SCALAR;
+/**end repeat**/
+
+ memset(_npy_can_cast_safely_table, 0, sizeof(_npy_can_cast_safely_table));
+
+ for (i = 0; i < NPY_NTYPES; ++i) {
+ /* Identity */
+ _npy_can_cast_safely_table[i][i] = 1;
+ /* Bool -> <Anything> */
+ _npy_can_cast_safely_table[PyArray_BOOL][i] = 1;
+ /* DateTime sits out for these... */
+ if (i != PyArray_DATETIME && i != PyArray_TIMEDELTA) {
+ /* <Anything> -> Object */
+ _npy_can_cast_safely_table[i][PyArray_OBJECT] = 1;
+ /* <Anything> -> Void */
+ _npy_can_cast_safely_table[i][PyArray_VOID] = 1;
+ }
+ }
+
+ _npy_can_cast_safely_table[PyArray_STRING][PyArray_UNICODE] = 1;
+
+#ifndef NPY_SIZEOF_BYTE
+#define NPY_SIZEOF_BYTE 1
+#endif
+
+ /* Compile-time loop of casting rules */
+/**begin repeat
+ * #FROM_NAME = BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
+ * LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE#
+ * #FROM_BASENAME = BYTE, BYTE, SHORT, SHORT, INT, INT, LONG, LONG,
+ * LONGLONG, LONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
+ * FLOAT, DOUBLE, LONGDOUBLE#
+ * #from_isint = 1, 0, 1, 0, 1, 0, 1, 0,
+ * 1, 0, 0, 0, 0,
+ * 0, 0, 0#
+ * #from_isuint = 0, 1, 0, 1, 0, 1, 0, 1,
+ * 0, 1, 0, 0, 0,
+ * 0, 0, 0#
+ * #from_isfloat = 0, 0, 0, 0, 0, 0, 0, 0,
+ * 0, 0, 1, 1, 1,
+ * 0, 0, 0#
+ * #from_iscomplex = 0, 0, 0, 0, 0, 0, 0, 0,
+ * 0, 0, 0, 0, 0,
+ * 1, 1, 1#
+ */
+#define _FROM_BSIZE NPY_SIZEOF_@FROM_BASENAME@
+#define _FROM_NUM (PyArray_@FROM_NAME@)
+
+ _npy_can_cast_safely_table[_FROM_NUM][PyArray_STRING] = 1;
+ _npy_can_cast_safely_table[_FROM_NUM][PyArray_UNICODE] = 1;
+
+/**begin repeat1
+ * #TO_NAME = BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
+ * LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
+ * CFLOAT, CDOUBLE, CLONGDOUBLE#
+ * #TO_BASENAME = BYTE, BYTE, SHORT, SHORT, INT, INT, LONG, LONG,
+ * LONGLONG, LONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
+ * FLOAT, DOUBLE, LONGDOUBLE#
+ * #to_isint = 1, 0, 1, 0, 1, 0, 1, 0,
+ * 1, 0, 0, 0, 0,
+ * 0, 0, 0#
+ * #to_isuint = 0, 1, 0, 1, 0, 1, 0, 1,
+ * 0, 1, 0, 0, 0,
+ * 0, 0, 0#
+ * #to_isfloat = 0, 0, 0, 0, 0, 0, 0, 0,
+ * 0, 0, 1, 1, 1,
+ * 0, 0, 0#
+ * #to_iscomplex = 0, 0, 0, 0, 0, 0, 0, 0,
+ * 0, 0, 0, 0, 0,
+ * 1, 1, 1#
+ */
+#define _TO_BSIZE NPY_SIZEOF_@TO_BASENAME@
+#define _TO_NUM (PyArray_@TO_NAME@)
+
+/*
+ * NOTE: _FROM_BSIZE and _TO_BSIZE are the sizes of the "base type"
+ * which is the same as the size of the type except for
+ * complex, where it is the size of the real type.
+ */
+
+#if @from_isint@
+
+# if @to_isint@ && (_TO_BSIZE >= _FROM_BSIZE)
+ /* int -> int */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# elif @to_isfloat@ && (_FROM_BSIZE < 8) && (_TO_BSIZE > _FROM_BSIZE)
+ /* int -> float */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# elif @to_isfloat@ && (_FROM_BSIZE >= 8) && (_TO_BSIZE >= _FROM_BSIZE)
+ /* int -> float */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# elif @to_iscomplex@ && (_FROM_BSIZE < 8) && (_TO_BSIZE > _FROM_BSIZE)
+ /* int -> complex */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# elif @to_iscomplex@ && (_FROM_BSIZE >= 8) && (_TO_BSIZE >= _FROM_BSIZE)
+ /* int -> complex */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# endif
+
+#elif @from_isuint@
+
+# if @to_isint@ && (_TO_BSIZE > _FROM_BSIZE)
+ /* uint -> int */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# elif @to_isuint@ && (_TO_BSIZE >= _FROM_BSIZE)
+ /* uint -> uint */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# elif @to_isfloat@ && (_FROM_BSIZE < 8) && (_TO_BSIZE > _FROM_BSIZE)
+ /* uint -> float */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# elif @to_isfloat@ && (_FROM_BSIZE >= 8) && (_TO_BSIZE >= _FROM_BSIZE)
+ /* uint -> float */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# elif @to_iscomplex@ && (_FROM_BSIZE < 8) && (_TO_BSIZE > _FROM_BSIZE)
+ /* uint -> complex */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# elif @to_iscomplex@ && (_FROM_BSIZE >= 8) && (_TO_BSIZE >= _FROM_BSIZE)
+ /* uint -> complex */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# endif
+
+
+#elif @from_isfloat@
+
+# if @to_isfloat@ && (_TO_BSIZE >= _FROM_BSIZE)
+ /* float -> float */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# elif @to_iscomplex@ && (_TO_BSIZE >= _FROM_BSIZE)
+ /* float -> complex */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# endif
+
+#elif @from_iscomplex@
+
+# if @to_iscomplex@ && (_TO_BSIZE >= _FROM_BSIZE)
+ /* complex -> complex */
+ _npy_can_cast_safely_table[_FROM_NUM][_TO_NUM] = 1;
+# endif
+
+#endif
+
+#undef _TO_NUM
+#undef _TO_BSIZE
+
+/**end repeat1**/
+
+#undef _FROM_NUM
+#undef _FROM_BSIZE
+
+/**end repeat**/
+
+}
+
static PyNumberMethods longdoubletype_as_number;
static PyNumberMethods clongdoubletype_as_number;
diff --git a/numpy/core/src/multiarray/scalartypes.h b/numpy/core/src/multiarray/scalartypes.h
index 893c0051d..c60f61dfb 100644
--- a/numpy/core/src/multiarray/scalartypes.h
+++ b/numpy/core/src/multiarray/scalartypes.h
@@ -1,6 +1,22 @@
#ifndef _NPY_SCALARTYPES_H_
#define _NPY_SCALARTYPES_H_
+/* Internal look-up tables */
+#ifdef NPY_ENABLE_SEPARATE_COMPILATION
+extern NPY_NO_EXPORT unsigned char
+_npy_can_cast_safely_table[NPY_NTYPES][NPY_NTYPES];
+extern NPY_NO_EXPORT char
+_npy_scalar_kinds_table[NPY_NTYPES];
+#else
+NPY_NO_EXPORT unsigned char
+_npy_can_cast_safely_table[NPY_NTYPES][NPY_NTYPES];
+NPY_NO_EXPORT char
+_npy_scalar_kinds_table[NPY_NTYPES];
+#endif
+
+NPY_NO_EXPORT void
+initialize_casting_tables(void);
+
NPY_NO_EXPORT void
initialize_numeric_types(void);
diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c
index 203792914..594722695 100644
--- a/numpy/core/src/multiarray/usertypes.c
+++ b/numpy/core/src/multiarray/usertypes.c
@@ -225,6 +225,19 @@ NPY_NO_EXPORT int
PyArray_RegisterCanCast(PyArray_Descr *descr, int totype,
NPY_SCALARKIND scalar)
{
+ /*
+ * If we were to allow this, the casting lookup table for
+ * built-in types needs to be modified, as cancastto is
+ * not checked for them.
+ */
+ if (!PyTypeNum_ISUSERDEF(descr->type_num) &&
+ !PyTypeNum_ISUSERDEF(totype)) {
+ PyErr_SetString(PyExc_ValueError,
+ "At least one of the types provided to"
+ "RegisterCanCast must be user-defined.");
+ return -1;
+ }
+
if (scalar == PyArray_NOSCALAR) {
/*
* register with cancastto
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index fec9895b6..d8f179454 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -330,6 +330,49 @@ class TestFloatExceptions(TestCase):
finally:
np.seterr(**oldsettings)
+class TestCoercion(TestCase):
+ def test_coercion(self):
+ """Tests that the scalars get coerced correctly."""
+ i8, i16, i32, i64 = int8(0), int16(0), int32(0), int64(0)
+ u8, u16, u32, u64 = uint8(0), uint16(0), uint32(0), uint64(0)
+ f32, f64, fld = float32(0), float64(0), longdouble(0)
+ c64, c128, cld = complex64(0), complex128(0), clongdouble(0)
+
+ # coercion within the same type
+ assert_equal(np.add(i8,i16).dtype, int16)
+ assert_equal(np.add(i32,i8).dtype, int32)
+ assert_equal(np.add(i16,i64).dtype, int64)
+ assert_equal(np.add(u8,u32).dtype, uint32)
+ assert_equal(np.add(f32,f64).dtype, float64)
+ assert_equal(np.add(fld,f32).dtype, longdouble)
+ assert_equal(np.add(f64,fld).dtype, longdouble)
+ assert_equal(np.add(c128,c64).dtype, complex128)
+ assert_equal(np.add(cld,c128).dtype, clongdouble)
+ assert_equal(np.add(c64,fld).dtype, clongdouble)
+
+ # coercion between types
+ assert_equal(np.add(i8,u8).dtype, int16)
+ assert_equal(np.add(u8,i32).dtype, int32)
+ assert_equal(np.add(i64,u32).dtype, int64)
+ assert_equal(np.add(u64,i32).dtype, float64)
+ assert_equal(np.add(i32,f32).dtype, float64)
+ assert_equal(np.add(i64,f32).dtype, float64)
+ assert_equal(np.add(f32,i16).dtype, float32)
+ assert_equal(np.add(f32,u32).dtype, float64)
+ assert_equal(np.add(f32,c64).dtype, complex64)
+ assert_equal(np.add(c128,f32).dtype, complex128)
+ assert_equal(np.add(cld,f64).dtype, clongdouble)
+
+ # coercion between scalars and 1-D arrays
+ assert_equal(np.add(array([i8]),i64).dtype, int8)
+ assert_equal(np.add(u64,array([i32])).dtype, int32)
+ assert_equal(np.add(i64,array([u32])).dtype, uint32)
+ assert_equal(np.add(int32(-1),array([u64])).dtype, float64)
+ assert_equal(np.add(f64,array([f32])).dtype, float32)
+ assert_equal(np.add(fld,array([f32])).dtype, float32)
+ assert_equal(np.add(array([f64]),fld).dtype, float64)
+ assert_equal(np.add(fld,array([c64])).dtype, complex64)
+ assert_equal(np.add(c64,array([f64])).dtype, complex128)
class TestFromiter(TestCase):
def makegen(self):
diff --git a/numpy/testing/print_coercion_tables.py b/numpy/testing/print_coercion_tables.py
new file mode 100755
index 000000000..0c8a87d9a
--- /dev/null
+++ b/numpy/testing/print_coercion_tables.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python
+"""Prints type-coercion tables for the built-in NumPy types"""
+
+import numpy as np
+
+# Generic object that can be added, but doesn't do anything else
+class GenericObject:
+ def __init__(self, v):
+ self.v = v
+
+ def __add__(self, other):
+ return self
+
+ def __radd__(self, other):
+ return self
+
+def print_cancast_table(ntypes):
+ print 'X',
+ for char in ntypes: print char,
+ print
+ for row in ntypes:
+ print row,
+ for col in ntypes:
+ print int(np.can_cast(row, col)),
+ print
+
+def print_coercion_table(ntypes, inputfirstvalue, inputsecondvalue, firstarray):
+ print '+',
+ for char in ntypes: print char,
+ print
+ for row in ntypes:
+ if row == 'O':
+ rowtype = GenericObject
+ else:
+ rowtype = np.obj2sctype(row)
+
+ print row,
+ for col in ntypes:
+ if col == 'O':
+ coltype = GenericObject
+ else:
+ coltype = np.obj2sctype(col)
+ try:
+ if firstarray:
+ rowvalue = np.array([rowtype(inputfirstvalue)], dtype=rowtype)
+ else:
+ rowvalue = rowtype(inputfirstvalue)
+ colvalue = coltype(inputsecondvalue)
+ value = np.add(rowvalue,colvalue)
+ if isinstance(value, np.ndarray):
+ char = value.dtype.char
+ else:
+ char = np.dtype(type(value)).char
+ except ValueError:
+ char = '!'
+ except OverflowError:
+ char = '@'
+ except TypeError:
+ char = '#'
+ print char,
+ print
+
+print "can cast"
+print_cancast_table(np.typecodes['All'])
+print
+print "In these tables, ValueError is '!', OverflowError is '@', TypeError is '#'"
+print
+print "scalar + scalar"
+print_coercion_table(np.typecodes['All'], 0, 0, False)
+print
+print "scalar + neg scalar"
+print_coercion_table(np.typecodes['All'], 0, -1, False)
+print
+print "array + scalar"
+print_coercion_table(np.typecodes['All'], 0, 0, True)
+print
+print "array + neg scalar"
+print_coercion_table(np.typecodes['All'], 0, -1, True)
+