summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJulian Taylor <jtaylor.debian@googlemail.com>2014-06-04 23:58:38 +0200
committerJulian Taylor <jtaylor.debian@googlemail.com>2014-06-05 00:03:15 +0200
commite2e49c73acaf5bbfd7d6a7d412218991d9899cec (patch)
tree374603b99dec5d14ff4ee8125f51bb9ceec384d8
parentd856a7f8a1fdca371fe090d2eaf731d69f26e1dd (diff)
downloadnumpy-e2e49c73acaf5bbfd7d6a7d412218991d9899cec.tar.gz
BUG: fix where not filling string types properly
the copyswap part of where used the input arrays descriptions to copy into the destination so if they had a smaller size the destination was not properly padded with zeros. Closes gh-4778
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c9
-rw-r--r--numpy/core/tests/test_multiarray.py13
2 files changed, 18 insertions, 4 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index c34bdf583..682705a1b 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -2790,6 +2790,9 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
needs_api = NpyIter_IterationNeedsAPI(iter);
+ /* Get the result from the iterator object array */
+ ret = (PyObject*)NpyIter_GetOperandArray(iter)[0];
+
NPY_BEGIN_THREADS_NDITER(iter);
if (NpyIter_GetIterSize(iter) != 0) {
@@ -2836,10 +2839,10 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
npy_intp i;
for (i = 0; i < n; i++) {
if (*csrc) {
- copyswapx(dst, xsrc, axswap, ax);
+ copyswapx(dst, xsrc, axswap, ret);
}
else {
- copyswapy(dst, ysrc, ayswap, ay);
+ copyswapy(dst, ysrc, ayswap, ret);
}
dst += itemsize;
xsrc += xstride;
@@ -2852,8 +2855,6 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
NPY_END_THREADS;
- /* Get the result from the iterator object array */
- ret = (PyObject*)NpyIter_GetOperandArray(iter)[0];
Py_INCREF(ret);
Py_DECREF(arr);
Py_DECREF(ax);
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index bc3670653..2e40a2b7c 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4354,6 +4354,19 @@ class TestWhere(TestCase):
assert_raises(ValueError, np.where, c, a, a)
assert_raises(ValueError, np.where, c[0], a, b)
+ def test_string(self):
+ # gh-4778 check strings are properly filled with nulls
+ a = np.array("abc")
+ b = np.array("x" * 753)
+ assert_equal(np.where(True, a, b), "abc")
+ assert_equal(np.where(False, b, a), "abc")
+
+ # check native datatype sized strings
+ a = np.array("abcd")
+ b = np.array("x" * 8)
+ assert_equal(np.where(True, a, b), "abcd")
+ assert_equal(np.where(False, b, a), "abcd")
+
if __name__ == "__main__":
run_module_suite()