summaryrefslogtreecommitdiff
path: root/numpy/core/src/multiarraymodule.c
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/src/multiarraymodule.c')
-rw-r--r--numpy/core/src/multiarraymodule.c33
1 files changed, 33 insertions, 0 deletions
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c
index d55987959..0d6f2e4b4 100644
--- a/numpy/core/src/multiarraymodule.c
+++ b/numpy/core/src/multiarraymodule.c
@@ -216,6 +216,39 @@ power_of_ten(int n)
static PyObject *
PyArray_Round(PyArrayObject *a, int decimals)
{
+ if (PyArray_ISCOMPLEX(a)) {
+ PyObject *part;
+ PyObject *round_part;
+ PyObject *new;
+ int res;
+ new = PyArray_Copy(a);
+ if (new == NULL) return NULL;
+
+ /* new.real = a.real.round(decimals) */
+ part = PyObject_GetAttrString(new, "real");
+ if (part == NULL) {Py_DECREF(new); return NULL;}
+ round_part = PyArray_Round\
+ ((PyArrayObject *)PyArray_EnsureAnyArray(part),
+ decimals);
+ Py_DECREF(part);
+ if (round_part == NULL) {Py_DECREF(new); return NULL;}
+ res = PyObject_SetAttrString(new, "real", round_part);
+ Py_DECREF(round_part);
+ if (res < 0) {Py_DECREF(new); return NULL;}
+
+ /* new.imag = a.imag.round(decimals) */
+ part = PyObject_GetAttrString(new, "imag");
+ if (part == NULL) {Py_DECREF(new); return NULL;}
+ round_part = PyArray_Round\
+ ((PyArrayObject *)PyArray_EnsureAnyArray(part),
+ decimals);
+ Py_DECREF(part);
+ if (round_part == NULL) {Py_DECREF(new); return NULL;}
+ res = PyObject_SetAttrString(new, "imag", round_part);
+ Py_DECREF(round_part);
+ if (res < 0) {Py_DECREF(new); return NULL;}
+ return new;
+ }
/* do the most common case first */
if (decimals == 0) {
if (PyArray_ISINTEGER(a)) {