diff options
author | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-06-04 23:58:38 +0200 |
---|---|---|
committer | Julian Taylor <jtaylor.debian@googlemail.com> | 2014-06-05 00:03:15 +0200 |
commit | e2e49c73acaf5bbfd7d6a7d412218991d9899cec (patch) | |
tree | 374603b99dec5d14ff4ee8125f51bb9ceec384d8 | |
parent | d856a7f8a1fdca371fe090d2eaf731d69f26e1dd (diff) | |
download | numpy-e2e49c73acaf5bbfd7d6a7d412218991d9899cec.tar.gz |
BUG: fix where not filling string types properly
the copyswap part of where used the input arrays descriptions to copy
into the destination so if they had a smaller size the destination was
not properly padded with zeros.
Closes gh-4778
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 9 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 13 |
2 files changed, 18 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index c34bdf583..682705a1b 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -2790,6 +2790,9 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) needs_api = NpyIter_IterationNeedsAPI(iter); + /* Get the result from the iterator object array */ + ret = (PyObject*)NpyIter_GetOperandArray(iter)[0]; + NPY_BEGIN_THREADS_NDITER(iter); if (NpyIter_GetIterSize(iter) != 0) { @@ -2836,10 +2839,10 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) npy_intp i; for (i = 0; i < n; i++) { if (*csrc) { - copyswapx(dst, xsrc, axswap, ax); + copyswapx(dst, xsrc, axswap, ret); } else { - copyswapy(dst, ysrc, ayswap, ay); + copyswapy(dst, ysrc, ayswap, ret); } dst += itemsize; xsrc += xstride; @@ -2852,8 +2855,6 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) NPY_END_THREADS; - /* Get the result from the iterator object array */ - ret = (PyObject*)NpyIter_GetOperandArray(iter)[0]; Py_INCREF(ret); Py_DECREF(arr); Py_DECREF(ax); diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index bc3670653..2e40a2b7c 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -4354,6 +4354,19 @@ class TestWhere(TestCase): assert_raises(ValueError, np.where, c, a, a) assert_raises(ValueError, np.where, c[0], a, b) + def test_string(self): + # gh-4778 check strings are properly filled with nulls + a = np.array("abc") + b = np.array("x" * 753) + assert_equal(np.where(True, a, b), "abc") + assert_equal(np.where(False, b, a), "abc") + + # check native datatype sized strings + a = np.array("abcd") + b = np.array("x" * 8) + assert_equal(np.where(True, a, b), "abcd") + assert_equal(np.where(False, b, a), "abcd") + if __name__ == "__main__": run_module_suite() |