summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-01-21 20:26:40 -0500
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-02-01 14:39:40 -0500
commit7927b3cb4b57e79bf6b1b8a723583295fc0398eb (patch)
tree74d8d82dd0b2cb29b6915e38a0142439691719b1
parentbd0fe0aa2159fa1d4570b0c524097b7492f239bf (diff)
downloadnumpy-7927b3cb4b57e79bf6b1b8a723583295fc0398eb.tar.gz
BUG Ensure that operations on MaskedArray return any output given.
Without this commit, after an operation like ``` a = np.ma.MaskedArray([1., 2., 3.]) b = np.add(a, 1., out=a) ``` one would have `b is not a`. With this PR, it is guaranteed that `b is a`.
-rw-r--r--numpy/ma/core.py11
-rw-r--r--numpy/ma/tests/test_core.py9
2 files changed, 16 insertions, 4 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 49c790d4f..69c401e01 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -2988,8 +2988,11 @@ class MaskedArray(ndarray):
Wraps the numpy array and sets the mask according to context.
"""
- result = obj.view(type(self))
- result._update_from(self)
+ if obj is self: # for in-place operations
+ result = obj
+ else:
+ result = obj.view(type(self))
+ result._update_from(self)
if context is not None:
result._mask = result._mask.copy()
@@ -3017,7 +3020,7 @@ class MaskedArray(ndarray):
except KeyError:
# Domain not recognized, use fill_value instead
fill_value = self.fill_value
- result = result.copy()
+
np.copyto(result, fill_value, where=d)
# Update the mask
@@ -3028,7 +3031,7 @@ class MaskedArray(ndarray):
m = (m | d)
# Make sure the mask has the proper size
- if result.shape == () and m:
+ if result is not self and result.shape == () and m:
return masked
else:
result._mask = m
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index c6f276211..2b491b47e 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -4494,6 +4494,15 @@ def test_default_fill_value_complex():
# regression test for Python 3, where 'unicode' was not defined
assert_(default_fill_value(1 + 1j) == 1.e20 + 0.0j)
+
+def test_ufunc_with_output():
+ # check that giving an output argument always returns that output.
+ # Regression test for gh-8416.
+ x = array([1., 2., 3.], mask=[0, 0, 1])
+ y = np.add(x, 1., out=x)
+ assert_(y is x)
+
+
###############################################################################
if __name__ == "__main__":
run_module_suite()