summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c23
-rw-r--r--numpy/core/tests/test_maskna.py9
2 files changed, 31 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index d23ecef83..003be5e39 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -1709,7 +1709,9 @@ _prepend_ones(PyArrayObject *arr, int nd, int ndmin)
Py_INCREF(dtype);
ret = (PyArrayObject *)PyArray_NewFromDescr(Py_TYPE(arr),
dtype, ndmin, newdims, newstrides,
- PyArray_DATA(arr), PyArray_FLAGS(arr), (PyObject *)arr);
+ PyArray_DATA(arr),
+ PyArray_FLAGS(arr) & ~(NPY_ARRAY_MASKNA | NPY_ARRAY_OWNMASKNA),
+ (PyObject *)arr);
if (ret == NULL) {
return NULL;
}
@@ -1718,6 +1720,25 @@ _prepend_ones(PyArrayObject *arr, int nd, int ndmin)
Py_DECREF(ret);
return NULL;
}
+
+ /* Take a view of the NA mask as well if necessary */
+ if (PyArray_HASMASKNA(arr)) {
+ PyArrayObject_fieldaccess *fret = (PyArrayObject_fieldaccess *)ret;
+
+ fret->maskna_dtype = PyArray_MASKNA_DTYPE(arr);
+ Py_INCREF(fret->maskna_dtype);
+ fret->maskna_data = PyArray_MASKNA_DATA(arr);
+
+ for (i = 0; i < num; ++i) {
+ fret->maskna_strides[i] = 0;
+ }
+ for (i = num; i < ndmin; ++i) {
+ fret->maskna_strides[i] = PyArray_MASKNA_STRIDES(arr)[i - num];
+ }
+ fret->flags |= NPY_ARRAY_MASKNA;
+ }
+
+
return (PyObject *)ret;
}
diff --git a/numpy/core/tests/test_maskna.py b/numpy/core/tests/test_maskna.py
index 8c0f9272a..0f63cc8bd 100644
--- a/numpy/core/tests/test_maskna.py
+++ b/numpy/core/tests/test_maskna.py
@@ -943,5 +943,14 @@ def test_array_maskna_concatenate():
assert_equal(res[~np.isna(res)], [0,1,4,2,5,10])
assert_equal(res.strides, (4, 16))
+def test_array_maskna_column_stack():
+ a = np.array((1,2,3), maskna=True)
+ b = np.array((2,3,4), maskna=True)
+ b[2] = np.NA
+ res = np.column_stack((a,b))
+ assert_equal(np.isna(res), [[0,0], [0,0], [0,1]])
+ assert_equal(res[~np.isna(res)], [1,2,2,3,3])
+
+
if __name__ == "__main__":
run_module_suite()