diff options
Diffstat (limited to 'numpy/core/src/multiarraymodule.c')
-rw-r--r-- | numpy/core/src/multiarraymodule.c | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c index d55987959..0d6f2e4b4 100644 --- a/numpy/core/src/multiarraymodule.c +++ b/numpy/core/src/multiarraymodule.c @@ -216,6 +216,39 @@ power_of_ten(int n) static PyObject * PyArray_Round(PyArrayObject *a, int decimals) { + if (PyArray_ISCOMPLEX(a)) { + PyObject *part; + PyObject *round_part; + PyObject *new; + int res; + new = PyArray_Copy(a); + if (new == NULL) return NULL; + + /* new.real = a.real.round(decimals) */ + part = PyObject_GetAttrString(new, "real"); + if (part == NULL) {Py_DECREF(new); return NULL;} + round_part = PyArray_Round\ + ((PyArrayObject *)PyArray_EnsureAnyArray(part), + decimals); + Py_DECREF(part); + if (round_part == NULL) {Py_DECREF(new); return NULL;} + res = PyObject_SetAttrString(new, "real", round_part); + Py_DECREF(round_part); + if (res < 0) {Py_DECREF(new); return NULL;} + + /* new.imag = a.imag.round(decimals) */ + part = PyObject_GetAttrString(new, "imag"); + if (part == NULL) {Py_DECREF(new); return NULL;} + round_part = PyArray_Round\ + ((PyArrayObject *)PyArray_EnsureAnyArray(part), + decimals); + Py_DECREF(part); + if (round_part == NULL) {Py_DECREF(new); return NULL;} + res = PyObject_SetAttrString(new, "imag", round_part); + Py_DECREF(round_part); + if (res < 0) {Py_DECREF(new); return NULL;} + return new; + } /* do the most common case first */ if (decimals == 0) { if (PyArray_ISINTEGER(a)) { |