summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src34
-rw-r--r--numpy/core/tests/test_regression.py5
2 files changed, 39 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index f4212eb04..3cc3e3d19 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1468,6 +1468,22 @@ gentype_setflags(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args),
return Py_None;
}
+/* casting complex numbers (that don't inherit from Python complex)
+ * to Python complex */
+
+/**begin repeat
+ * #name=cfloat,clongdouble#
+ * #Name=CFloat,CLongDouble#
+ */
+static PyObject *
+@name@_complex(PyObject *self, PyObject *NPY_UNUSED(args),
+ PyObject *NPY_UNUSED(kwds))
+{
+ return PyComplex_FromDoubles(PyArrayScalar_VAL(self, @Name@).real,
+ PyArrayScalar_VAL(self, @Name@).imag);
+}
+/**end repeat**/
+
/* need to fill in doc-strings for these methods on import -- copy from
array docstrings
*/
@@ -1687,6 +1703,17 @@ static PyMethodDef voidtype_methods[] = {
{NULL, NULL, 0, NULL}
};
+/**begin repeat
+ * #name=cfloat,clongdouble#
+ */
+static PyMethodDef @name@type_methods[] = {
+ {"__complex__",
+ (PyCFunction)@name@_complex,
+ METH_VARARGS | METH_KEYWORDS, NULL},
+ {NULL, NULL, 0, NULL}
+};
+/**end repeat**/
+
/************* As_mapping functions for void array scalar ************/
static Py_ssize_t
@@ -3307,6 +3334,13 @@ initialize_numeric_types(void)
Py@NAME@ArrType_Type.tp_hash = @name@_arrtype_hash;
/**end repeat**/
+ /**begin repeat
+ * #name = cfloat, clongdouble#
+ * #NAME = CFloat, CLongDouble#
+ */
+ Py@NAME@ArrType_Type.tp_methods = @name@type_methods;
+ /**end repeat**/
+
#if (SIZEOF_INT != SIZEOF_LONG) || defined(NPY_PY3K)
/* We won't be inheriting from Python Int type. */
PyIntArrType_Type.tp_hash = int_arrtype_hash;
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index 9375ab7e4..49490c05a 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -1411,5 +1411,10 @@ class TestRegression(TestCase):
assert_equal(float(x), float(x.real))
ctx.__exit__()
+ def test_complex_scalar_complex_cast(self):
+ for tp in [np.csingle, np.cdouble, np.clongdouble]:
+ x = tp(1+2j)
+ assert_equal(complex(x), 1+2j)
+
if __name__ == "__main__":
run_module_suite()