diff options
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 23 | ||||
-rw-r--r-- | numpy/core/tests/test_maskna.py | 9 |
2 files changed, 31 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index d23ecef83..003be5e39 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -1709,7 +1709,9 @@ _prepend_ones(PyArrayObject *arr, int nd, int ndmin) Py_INCREF(dtype); ret = (PyArrayObject *)PyArray_NewFromDescr(Py_TYPE(arr), dtype, ndmin, newdims, newstrides, - PyArray_DATA(arr), PyArray_FLAGS(arr), (PyObject *)arr); + PyArray_DATA(arr), + PyArray_FLAGS(arr) & ~(NPY_ARRAY_MASKNA | NPY_ARRAY_OWNMASKNA), + (PyObject *)arr); if (ret == NULL) { return NULL; } @@ -1718,6 +1720,25 @@ _prepend_ones(PyArrayObject *arr, int nd, int ndmin) Py_DECREF(ret); return NULL; } + + /* Take a view of the NA mask as well if necessary */ + if (PyArray_HASMASKNA(arr)) { + PyArrayObject_fieldaccess *fret = (PyArrayObject_fieldaccess *)ret; + + fret->maskna_dtype = PyArray_MASKNA_DTYPE(arr); + Py_INCREF(fret->maskna_dtype); + fret->maskna_data = PyArray_MASKNA_DATA(arr); + + for (i = 0; i < num; ++i) { + fret->maskna_strides[i] = 0; + } + for (i = num; i < ndmin; ++i) { + fret->maskna_strides[i] = PyArray_MASKNA_STRIDES(arr)[i - num]; + } + fret->flags |= NPY_ARRAY_MASKNA; + } + + return (PyObject *)ret; } diff --git a/numpy/core/tests/test_maskna.py b/numpy/core/tests/test_maskna.py index 8c0f9272a..0f63cc8bd 100644 --- a/numpy/core/tests/test_maskna.py +++ b/numpy/core/tests/test_maskna.py @@ -943,5 +943,14 @@ def test_array_maskna_concatenate(): assert_equal(res[~np.isna(res)], [0,1,4,2,5,10]) assert_equal(res.strides, (4, 16)) +def test_array_maskna_column_stack(): + a = np.array((1,2,3), maskna=True) + b = np.array((2,3,4), maskna=True) + b[2] = np.NA + res = np.column_stack((a,b)) + assert_equal(np.isna(res), [[0,0], [0,0], [0,1]]) + assert_equal(res[~np.isna(res)], [1,2,2,3,3]) + + if __name__ == "__main__": run_module_suite() |