diff options
author | seberg <sebastian@sipsolutions.net> | 2016-03-23 14:14:08 +0100 |
---|---|---|
committer | seberg <sebastian@sipsolutions.net> | 2016-03-23 14:14:08 +0100 |
commit | 1380fdd8824f7b404a1a5b22aa4dd6bfbfcfbc1a (patch) | |
tree | 25d9bfc5bd08cbec64721cf6b594a211bd58b65e | |
parent | b479671bd6a3510e6823efc4d13e3eafcf7cf9dc (diff) | |
parent | 5ed19d18b2c719a045ccbadf0a7d95d08e1609c8 (diff) | |
download | numpy-1380fdd8824f7b404a1a5b22aa4dd6bfbfcfbc1a.tar.gz |
Merge pull request #7229 from ewmoore/__complex__
ENH: implement __complex__
-rw-r--r-- | doc/release/1.12.0-notes.rst | 6 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 74 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 41 |
3 files changed, 121 insertions, 0 deletions
diff --git a/doc/release/1.12.0-notes.rst b/doc/release/1.12.0-notes.rst index 31f7096d7..97ec25777 100644 --- a/doc/release/1.12.0-notes.rst +++ b/doc/release/1.12.0-notes.rst @@ -155,6 +155,12 @@ Generalized Ufuncs, including most of the linalg module, will now unlock the Python global interpreter lock. +The *__complex__* method has been implemented on the ndarray object +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Calling ``complex()`` on a size 1 array will now cast to a python +complex. + + Changes ======= diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 56b6086ff..b8e066b79 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -2308,6 +2308,76 @@ array_newbyteorder(PyArrayObject *self, PyObject *args) } +static PyObject * +array_complex(PyArrayObject *self, PyObject *NPY_UNUSED(args)) +{ + PyArrayObject *arr; + PyArray_Descr *dtype; + PyObject *c; + if (PyArray_SIZE(self) != 1) { + PyErr_SetString(PyExc_TypeError, "only length-1 arrays can "\ + "be converted to Python scalars"); + return NULL; + } + + dtype = PyArray_DescrFromType(NPY_CDOUBLE); + if (dtype == NULL) { + return NULL; + } + + if (!PyArray_CanCastArrayTo(self, dtype, NPY_SAME_KIND_CASTING) && + !(PyArray_TYPE(self) == NPY_OBJECT)) { + PyObject *err, *msg_part; + Py_DECREF(dtype); + err = PyString_FromString("unable to convert "); + if (err == NULL) { + return NULL; + } + msg_part = PyObject_Repr((PyObject*)PyArray_DESCR(self)); + if (msg_part == NULL) { + Py_DECREF(err); + return NULL; + } + PyString_ConcatAndDel(&err, msg_part); + if (err == NULL) { + return NULL; + } + msg_part = PyString_FromString(", to complex."); + if (msg_part == NULL) { + Py_DECREF(err); + return NULL; + } + PyString_ConcatAndDel(&err, msg_part); + if (err == NULL) { + return NULL; + } + PyErr_SetObject(PyExc_TypeError, err); + Py_DECREF(err); + return NULL; + } + + if (PyArray_TYPE(self) == NPY_OBJECT) { + /* let python try calling __complex__ on the object. */ + PyObject *args, *res; + Py_DECREF(dtype); + args = Py_BuildValue("(O)", *((PyObject**)PyArray_DATA(self))); + if (args == NULL) { + return NULL; + } + res = PyComplex_Type.tp_new(&PyComplex_Type, args, NULL); + Py_DECREF(args); + return res; + } + + arr = (PyArrayObject *)PyArray_CastToType(self, dtype, 0); + if (arr == NULL) { + return NULL; + } + c = PyComplex_FromCComplex(*((Py_complex*)PyArray_DATA(arr))); + Py_DECREF(arr); + return c; +} + NPY_NO_EXPORT PyMethodDef array_methods[] = { /* for subtypes */ @@ -2348,6 +2418,10 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { (PyCFunction) array_dump, METH_VARARGS, NULL}, + {"__complex__", + (PyCFunction) array_complex, + METH_VARARGS, NULL}, + /* Original and Extended methods added 2005 */ {"all", (PyCFunction)array_all, diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 14a902b9c..ffdc50d2c 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2476,6 +2476,47 @@ class TestMethods(TestCase): assert_raises(AttributeError, lambda: a.conj()) assert_raises(AttributeError, lambda: a.conjugate()) + def test__complex__(self): + dtypes = ['i1', 'i2', 'i4', 'i8', + 'u1', 'u2', 'u4', 'u8', + 'f', 'd', 'g', 'F', 'D', 'G', + '?', 'O'] + for dt in dtypes: + a = np.array(7, dtype=dt) + b = np.array([7], dtype=dt) + c = np.array([[[[[7]]]]], dtype=dt) + + msg = 'dtype: {0}'.format(dt) + ap = complex(a) + assert_equal(ap, a, msg) + bp = complex(b) + assert_equal(bp, b, msg) + cp = complex(c) + assert_equal(cp, c, msg) + + def test__complex__should_not_work(self): + dtypes = ['i1', 'i2', 'i4', 'i8', + 'u1', 'u2', 'u4', 'u8', + 'f', 'd', 'g', 'F', 'D', 'G', + '?', 'O'] + for dt in dtypes: + a = np.array([1, 2, 3], dtype=dt) + assert_raises(TypeError, complex, a) + + dt = np.dtype([('a', 'f8'), ('b', 'i1')]) + b = np.array((1.0, 3), dtype=dt) + assert_raises(TypeError, complex, b) + + c = np.array([(1.0, 3), (2e-3, 7)], dtype=dt) + assert_raises(TypeError, complex, c) + + d = np.array('1+1j') + assert_raises(TypeError, complex, d) + + e = np.array(['1+1j'], 'U') + assert_raises(TypeError, complex, e) + + class TestBinop(object): def test_inplace(self): # test refcount 1 inplace conversion |