diff options
author | Pauli Virtanen <pav@iki.fi> | 2010-09-21 01:43:09 +0200 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2010-09-21 01:46:32 +0200 |
commit | 14d8e209b0fef60fdda9ed38e4a3dec8ddf91996 (patch) | |
tree | 449f2f55b3893617a2f9f0e6e6821dfb3eb9381a /numpy | |
parent | d82003ba6ab7c0037cadbcee81d4a24463f33589 (diff) | |
download | numpy-14d8e209b0fef60fdda9ed38e4a3dec8ddf91996.tar.gz |
BUG: core: ensure cfloat and clongdouble scalars have a __complex__ method, so that complex(...) cast works properly (fixes #1617)
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 34 | ||||
-rw-r--r-- | numpy/core/tests/test_regression.py | 5 |
2 files changed, 39 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index f4212eb04..3cc3e3d19 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1468,6 +1468,22 @@ gentype_setflags(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args), return Py_None; } +/* casting complex numbers (that don't inherit from Python complex) + * to Python complex */ + +/**begin repeat + * #name=cfloat,clongdouble# + * #Name=CFloat,CLongDouble# + */ +static PyObject * +@name@_complex(PyObject *self, PyObject *NPY_UNUSED(args), + PyObject *NPY_UNUSED(kwds)) +{ + return PyComplex_FromDoubles(PyArrayScalar_VAL(self, @Name@).real, + PyArrayScalar_VAL(self, @Name@).imag); +} +/**end repeat**/ + /* need to fill in doc-strings for these methods on import -- copy from array docstrings */ @@ -1687,6 +1703,17 @@ static PyMethodDef voidtype_methods[] = { {NULL, NULL, 0, NULL} }; +/**begin repeat + * #name=cfloat,clongdouble# + */ +static PyMethodDef @name@type_methods[] = { + {"__complex__", + (PyCFunction)@name@_complex, + METH_VARARGS | METH_KEYWORDS, NULL}, + {NULL, NULL, 0, NULL} +}; +/**end repeat**/ + /************* As_mapping functions for void array scalar ************/ static Py_ssize_t @@ -3307,6 +3334,13 @@ initialize_numeric_types(void) Py@NAME@ArrType_Type.tp_hash = @name@_arrtype_hash; /**end repeat**/ + /**begin repeat + * #name = cfloat, clongdouble# + * #NAME = CFloat, CLongDouble# + */ + Py@NAME@ArrType_Type.tp_methods = @name@type_methods; + /**end repeat**/ + #if (SIZEOF_INT != SIZEOF_LONG) || defined(NPY_PY3K) /* We won't be inheriting from Python Int type. */ PyIntArrType_Type.tp_hash = int_arrtype_hash; diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py index 9375ab7e4..49490c05a 100644 --- a/numpy/core/tests/test_regression.py +++ b/numpy/core/tests/test_regression.py @@ -1411,5 +1411,10 @@ class TestRegression(TestCase): assert_equal(float(x), float(x.real)) ctx.__exit__() + def test_complex_scalar_complex_cast(self): + for tp in [np.csingle, np.cdouble, np.clongdouble]: + x = tp(1+2j) + assert_equal(complex(x), 1+2j) + if __name__ == "__main__": run_module_suite() |