summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Kirkham <kirkhamj@janelia.hhmi.org>2015-02-19 01:23:50 -0500
committerCharles Harris <charlesr.harris@gmail.com>2015-05-02 11:28:02 -0600
commit602836b99c2a046574dd352def24980b2da9d9b2 (patch)
tree0fcd5b2f14d371a68decd07339c40f437275c4ce
parentf3e7ef7f034bddcee856f6722db7035ba549ef19 (diff)
downloadnumpy-602836b99c2a046574dd352def24980b2da9d9b2.tar.gz
BUG: Fix `numpy.ma.where` to be consistent with unmasked version.
Closes #5679.
-rw-r--r--numpy/ma/core.py34
-rw-r--r--numpy/ma/tests/test_core.py7
2 files changed, 37 insertions, 4 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index b52dad9ac..b52dc0ab7 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -6708,7 +6708,7 @@ size.__doc__ = np.size.__doc__
#####--------------------------------------------------------------------------
#---- --- Extra functions ---
#####--------------------------------------------------------------------------
-def where (condition, x=None, y=None):
+def where(condition, *args, **kwargs):
"""
Return a masked array with elements from x or y, depending on condition.
@@ -6755,10 +6755,36 @@ def where (condition, x=None, y=None):
[6.0 -- 8.0]]
"""
- if x is None and y is None:
+
+ if ((len(args) == 0) or (len(args) == 2)) and (len(kwargs) == 0):
+ pass
+ elif (len(args) == 1) and (len(kwargs) == 1):
+ if ("x" not in kwargs):
+ raise ValueError(
+ "Cannot provide `x` as an argument and a keyword argument."
+ )
+ if ("y" not in kwargs):
+ raise ValueError(
+ "Must provide `y` as a keyword argument if not as an argument."
+ )
+ args += (kwargs["y"],)
+ elif (len(args) == 0) and ((len(kwargs) == 0) or (len(kwargs) == 2)):
+ if (("x" not in kwargs) and ("y" not in kwargs) or
+ ("x" in kwargs) and ("y" in kwargs)):
+ raise ValueError(
+ "Must provide both `x` and `y` as keyword arguments or neither."
+ )
+ args += (kwargs["x"], kwargs["y"],)
+ else:
+ raise ValueError(
+ "Only takes 3 arguments was provided %i arguments." %
+ len(args) + len(kwargs)
+ )
+
+
+ if len(args) == 0:
return filled(condition, 0).nonzero()
- elif x is None or y is None:
- raise ValueError("Either both or neither x and y should be given.")
+ x, y = args
# Get the condition ...............
fc = filled(condition, 0).astype(MaskType)
notfc = np.logical_not(fc)
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index ea266669e..f8a28164e 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -3396,6 +3396,13 @@ class TestMaskedArrayFunctions(TestCase):
assert_equal(d, [-9, -9, -9, -9, -9, 4, -9, -9, 10, -9, -9, 3])
assert_equal(d.dtype, ixm.dtype)
+ def test_where_object(self):
+ a = np.array(None)
+ b = masked_array(None)
+ r = b.copy()
+ assert_equal(np.ma.where(True, a, a), r)
+ assert_equal(np.ma.where(True, b, b), r)
+
def test_where_with_masked_choice(self):
x = arange(10)
x[3] = masked