diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2019-12-02 14:17:39 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-12-02 14:17:39 -0600 |
commit | 5992098524c9f36288093ef3298d44343735842e (patch) | |
tree | bbcf59bcd293b763841c23f076927c0aa96533d8 /numpy | |
parent | 7b2d968d5a4730489d9e9148afe2277b1bc32477 (diff) | |
parent | 14bcfd9cfe0deb4e6499b398d7eba4d7e3dd7fe8 (diff) | |
download | numpy-5992098524c9f36288093ef3298d44343735842e.tar.gz |
Merge pull request #14996 from eric-wieser/masked_rows-bad-argument
DEP: Deprecate the axis argument to masked_rows and masked_cols
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/extras.py | 16 | ||||
-rw-r--r-- | numpy/ma/tests/test_extras.py | 13 |
2 files changed, 27 insertions, 2 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 4a83ac781..f4a914471 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -937,7 +937,7 @@ def compress_cols(a): raise NotImplementedError("compress_cols works for 2D arrays only.") return compress_rowcols(a, 1) -def mask_rows(a, axis=None): +def mask_rows(a, axis=np._NoValue): """ Mask rows of a 2D array that contain masked values. @@ -979,9 +979,15 @@ def mask_rows(a, axis=None): fill_value=1) """ + if axis is not np._NoValue: + # remove the axis argument when this deprecation expires + # NumPy 1.18.0, 2019-11-28 + warnings.warn( + "The axis argument has always been ignored, in future passing it " + "will raise TypeError", DeprecationWarning, stacklevel=2) return mask_rowcols(a, 0) -def mask_cols(a, axis=None): +def mask_cols(a, axis=np._NoValue): """ Mask columns of a 2D array that contain masked values. @@ -1022,6 +1028,12 @@ def mask_cols(a, axis=None): fill_value=1) """ + if axis is not np._NoValue: + # remove the axis argument when this deprecation expires + # NumPy 1.18.0, 2019-11-28 + warnings.warn( + "The axis argument has always been ignored, in future passing it " + "will raise TypeError", DeprecationWarning, stacklevel=2) return mask_rowcols(a, 1) diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index 836770378..c75c47801 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -11,6 +11,7 @@ from __future__ import division, absolute_import, print_function import warnings import itertools +import pytest import numpy as np from numpy.testing import ( @@ -552,6 +553,18 @@ class TestCompressFunctions(object): assert_(mask_rowcols(x, 0).mask.all()) assert_(mask_rowcols(x, 1).mask.all()) + @pytest.mark.parametrize("axis", [None, 0, 1]) + @pytest.mark.parametrize(["func", "rowcols_axis"], + [(np.ma.mask_rows, 0), (np.ma.mask_cols, 1)]) + def test_mask_row_cols_axis_deprecation(self, axis, func, rowcols_axis): + # Test deprecation of the axis argument to `mask_rows` and `mask_cols` + x = array(np.arange(9).reshape(3, 3), + mask=[[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + + with assert_warns(DeprecationWarning): + res = func(x, axis=axis) + assert_equal(res, mask_rowcols(x, rowcols_axis)) + def test_dot(self): # Tests dot product n = np.arange(1, 7) |