summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAntoine Pitrou <antoine@python.org>2015-04-23 12:01:44 +0200
committerAntoine Pitrou <antoine@python.org>2015-04-23 12:01:44 +0200
commitd6b20d1de81975293373f9bf234358f7338ef0a0 (patch)
treeab8a3cce0b982f9cefe1a10e87d9abed7ee52474
parent02b858326dac217607a83ed0bf4d7d51d5bfbfbe (diff)
downloadnumpy-d6b20d1de81975293373f9bf234358f7338ef0a0.tar.gz
BUG: fix round() on complex array with explicit output
Closes #5779.
-rw-r--r--numpy/core/src/multiarray/calculation.c4
-rw-r--r--numpy/core/tests/test_multiarray.py21
2 files changed, 18 insertions, 7 deletions
diff --git a/numpy/core/src/multiarray/calculation.c b/numpy/core/src/multiarray/calculation.c
index edcca9857..d4a08a4ee 100644
--- a/numpy/core/src/multiarray/calculation.c
+++ b/numpy/core/src/multiarray/calculation.c
@@ -618,7 +618,7 @@ PyArray_Round(PyArrayObject *a, int decimals, PyArrayObject *out)
}
/* arr.real = a.real.round(decimals) */
- part = PyObject_GetAttrString(arr, "real");
+ part = PyObject_GetAttrString(a, "real");
if (part == NULL) {
Py_DECREF(arr);
return NULL;
@@ -639,7 +639,7 @@ PyArray_Round(PyArrayObject *a, int decimals, PyArrayObject *out)
}
/* arr.imag = a.imag.round(decimals) */
- part = PyObject_GetAttrString(arr, "imag");
+ part = PyObject_GetAttrString(a, "imag");
if (part == NULL) {
Py_DECREF(arr);
return NULL;
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 3ee17859f..314adf4d1 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -868,11 +868,22 @@ class TestBool(TestCase):
self.assertEqual(np.count_nonzero(a), builtins.sum(a.tolist()))
class TestMethods(TestCase):
- def test_test_round(self):
- assert_equal(array([1.2, 1.5]).round(), [1, 2])
- assert_equal(array(1.5).round(), 2)
- assert_equal(array([12.2, 15.5]).round(-1), [10, 20])
- assert_equal(array([12.15, 15.51]).round(1), [12.2, 15.5])
+ def test_round(self):
+ def check_round(arr, expected, *round_args):
+ assert_equal(arr.round(*round_args), expected)
+ # With output array
+ out = np.zeros_like(arr)
+ res = arr.round(*round_args, out=out)
+ assert_equal(out, expected)
+ assert_equal(out, res)
+
+ check_round(array([1.2, 1.5]), [1, 2])
+ check_round(array(1.5), 2)
+ check_round(array([12.2, 15.5]), [10, 20], -1)
+ check_round(array([12.15, 15.51]), [12.2, 15.5], 1)
+ # Complex rounding
+ check_round(array([4.5 + 1.5j]), [4 + 2j])
+ check_round(array([12.5 + 15.5j]), [10 + 20j], -1)
def test_transpose(self):
a = array([[1, 2], [3, 4]])