summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-02-10 15:00:56 +0000
committerEric Wieser <wieser.eric@gmail.com>2017-03-08 00:26:05 +0000
commit2d229ef2bbd4d65b03e62c6e3aa7396034ffd218 (patch)
tree79442a8fc535f642a8e0e91869a19543f2441b53 /numpy
parent6a3edf3210b439a4d1a51acb4e01bac017697ee6 (diff)
downloadnumpy-2d229ef2bbd4d65b03e62c6e3aa7396034ffd218.tar.gz
BUG: Make np.ma.where delegate to np.where
Fixes #8600 and #8599 Also makes np.ma.masked work with structured dtypes.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py64
-rw-r--r--numpy/ma/tests/test_core.py32
2 files changed, 63 insertions, 33 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index e78d1601d..f3d0c0b97 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -6991,44 +6991,42 @@ def where(condition, x=_NoValue, y=_NoValue):
[6.0 -- 8.0]]
"""
- missing = (x is _NoValue, y is _NoValue).count(True)
+ # handle the single-argument case
+ missing = (x is _NoValue, y is _NoValue).count(True)
if missing == 1:
raise ValueError("Must provide both 'x' and 'y' or neither.")
if missing == 2:
- return filled(condition, 0).nonzero()
-
- # Both x and y are provided
-
- # Get the condition
- fc = filled(condition, 0).astype(MaskType)
- notfc = np.logical_not(fc)
-
- # Get the data
- xv = getdata(x)
- yv = getdata(y)
- if x is masked:
- ndtype = yv.dtype
- elif y is masked:
- ndtype = xv.dtype
- else:
- ndtype = np.find_common_type([xv.dtype, yv.dtype], [])
-
- # Construct an empty array and fill it
- d = np.empty(fc.shape, dtype=ndtype).view(MaskedArray)
- np.copyto(d._data, xv.astype(ndtype), where=fc)
- np.copyto(d._data, yv.astype(ndtype), where=notfc)
-
- # Create an empty mask and fill it
- mask = np.zeros(fc.shape, dtype=MaskType)
- np.copyto(mask, getmask(x), where=fc)
- np.copyto(mask, getmask(y), where=notfc)
- mask |= getmaskarray(condition)
-
- # Use d._mask instead of d.mask to avoid copies
- d._mask = mask if mask.any() else nomask
+ return nonzero(condition)
+
+ # we only care if the condition is true - false or masked pick y
+ cf = filled(condition, False)
+ xd = getdata(x)
+ yd = getdata(y)
+
+ # we need the full arrays here for correct final dimensions
+ cm = getmaskarray(condition)
+ xm = getmaskarray(x)
+ ym = getmaskarray(y)
+
+ # deal with the fact that masked.dtype == float64, but we don't actually
+ # want to treat it as that.
+ if x is masked and y is not masked:
+ xd = np.zeros((), dtype=yd.dtype)
+ xm = np.ones((), dtype=ym.dtype)
+ elif y is masked and x is not masked:
+ yd = np.zeros((), dtype=xd.dtype)
+ ym = np.ones((), dtype=xm.dtype)
+
+ data = np.where(cf, xd, yd)
+ mask = np.where(cf, xm, ym)
+ mask = np.where(cm, np.ones((), dtype=mask.dtype), mask)
+
+ # collapse the mask, for backwards compatibility
+ if mask.dtype == np.bool_ and not mask.any():
+ mask = nomask
- return d
+ return masked_array(data, mask=mask)
def choose(indices, choices, out=None, mode='raise'):
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index a65cac8c8..794889b92 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -3942,6 +3942,38 @@ class TestMaskedArrayFunctions(TestCase):
control = np.find_common_type([np.int32, np.float32], [])
assert_equal(test, control)
+ def test_where_broadcast(self):
+ # Issue 8599
+ x = np.arange(9).reshape(3, 3)
+ y = np.zeros(3)
+ core = np.where([1, 0, 1], x, y)
+ ma = where([1, 0, 1], x, y)
+
+ assert_equal(core, ma)
+ assert_equal(core.dtype, ma.dtype)
+
+ def test_where_structured(self):
+ # Issue 8600
+ dt = np.dtype([('a', int), ('b', int)])
+ x = np.array([(1, 2), (3, 4), (5, 6)], dtype=dt)
+ y = np.array((10, 20), dtype=dt)
+ core = np.where([0, 1, 1], x, y)
+ ma = np.where([0, 1, 1], x, y)
+
+ assert_equal(core, ma)
+ assert_equal(core.dtype, ma.dtype)
+
+ def test_where_structured_masked(self):
+ dt = np.dtype([('a', int), ('b', int)])
+ x = np.array([(1, 2), (3, 4), (5, 6)], dtype=dt)
+
+ ma = where([0, 1, 1], x, masked)
+ expected = masked_where([1, 0, 0], x)
+
+ assert_equal(ma.dtype, expected.dtype)
+ assert_equal(ma, expected)
+ assert_equal(ma.mask, expected.mask)
+
def test_choose(self):
# Test choose
choices = [[0, 1, 2, 3], [10, 11, 12, 13],