diff options
author | Antoine Pitrou <antoine@python.org> | 2015-04-23 12:01:44 +0200 |
---|---|---|
committer | Antoine Pitrou <antoine@python.org> | 2015-04-23 12:01:44 +0200 |
commit | d6b20d1de81975293373f9bf234358f7338ef0a0 (patch) | |
tree | ab8a3cce0b982f9cefe1a10e87d9abed7ee52474 | |
parent | 02b858326dac217607a83ed0bf4d7d51d5bfbfbe (diff) | |
download | numpy-d6b20d1de81975293373f9bf234358f7338ef0a0.tar.gz |
BUG: fix round() on complex array with explicit output
Closes #5779.
-rw-r--r-- | numpy/core/src/multiarray/calculation.c | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 21 |
2 files changed, 18 insertions, 7 deletions
diff --git a/numpy/core/src/multiarray/calculation.c b/numpy/core/src/multiarray/calculation.c index edcca9857..d4a08a4ee 100644 --- a/numpy/core/src/multiarray/calculation.c +++ b/numpy/core/src/multiarray/calculation.c @@ -618,7 +618,7 @@ PyArray_Round(PyArrayObject *a, int decimals, PyArrayObject *out) } /* arr.real = a.real.round(decimals) */ - part = PyObject_GetAttrString(arr, "real"); + part = PyObject_GetAttrString(a, "real"); if (part == NULL) { Py_DECREF(arr); return NULL; @@ -639,7 +639,7 @@ PyArray_Round(PyArrayObject *a, int decimals, PyArrayObject *out) } /* arr.imag = a.imag.round(decimals) */ - part = PyObject_GetAttrString(arr, "imag"); + part = PyObject_GetAttrString(a, "imag"); if (part == NULL) { Py_DECREF(arr); return NULL; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 3ee17859f..314adf4d1 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -868,11 +868,22 @@ class TestBool(TestCase): self.assertEqual(np.count_nonzero(a), builtins.sum(a.tolist())) class TestMethods(TestCase): - def test_test_round(self): - assert_equal(array([1.2, 1.5]).round(), [1, 2]) - assert_equal(array(1.5).round(), 2) - assert_equal(array([12.2, 15.5]).round(-1), [10, 20]) - assert_equal(array([12.15, 15.51]).round(1), [12.2, 15.5]) + def test_round(self): + def check_round(arr, expected, *round_args): + assert_equal(arr.round(*round_args), expected) + # With output array + out = np.zeros_like(arr) + res = arr.round(*round_args, out=out) + assert_equal(out, expected) + assert_equal(out, res) + + check_round(array([1.2, 1.5]), [1, 2]) + check_round(array(1.5), 2) + check_round(array([12.2, 15.5]), [10, 20], -1) + check_round(array([12.15, 15.51]), [12.2, 15.5], 1) + # Complex rounding + check_round(array([4.5 + 1.5j]), [4 + 2j]) + check_round(array([12.5 + 15.5j]), [10 + 20j], -1) def test_transpose(self): a = array([[1, 2], [3, 4]]) |