diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2017-01-21 20:26:40 -0500 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2017-02-01 14:39:40 -0500 |
commit | 7927b3cb4b57e79bf6b1b8a723583295fc0398eb (patch) | |
tree | 74d8d82dd0b2cb29b6915e38a0142439691719b1 | |
parent | bd0fe0aa2159fa1d4570b0c524097b7492f239bf (diff) | |
download | numpy-7927b3cb4b57e79bf6b1b8a723583295fc0398eb.tar.gz |
BUG Ensure that operations on MaskedArray return any output given.
Without this commit, after an operation like
```
a = np.ma.MaskedArray([1., 2., 3.])
b = np.add(a, 1., out=a)
```
one would have `b is not a`. With this PR, it is guaranteed that
`b is a`.
-rw-r--r-- | numpy/ma/core.py | 11 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 9 |
2 files changed, 16 insertions, 4 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 49c790d4f..69c401e01 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -2988,8 +2988,11 @@ class MaskedArray(ndarray): Wraps the numpy array and sets the mask according to context. """ - result = obj.view(type(self)) - result._update_from(self) + if obj is self: # for in-place operations + result = obj + else: + result = obj.view(type(self)) + result._update_from(self) if context is not None: result._mask = result._mask.copy() @@ -3017,7 +3020,7 @@ class MaskedArray(ndarray): except KeyError: # Domain not recognized, use fill_value instead fill_value = self.fill_value - result = result.copy() + np.copyto(result, fill_value, where=d) # Update the mask @@ -3028,7 +3031,7 @@ class MaskedArray(ndarray): m = (m | d) # Make sure the mask has the proper size - if result.shape == () and m: + if result is not self and result.shape == () and m: return masked else: result._mask = m diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index c6f276211..2b491b47e 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -4494,6 +4494,15 @@ def test_default_fill_value_complex(): # regression test for Python 3, where 'unicode' was not defined assert_(default_fill_value(1 + 1j) == 1.e20 + 0.0j) + +def test_ufunc_with_output(): + # check that giving an output argument always returns that output. + # Regression test for gh-8416. + x = array([1., 2., 3.], mask=[0, 0, 1]) + y = np.add(x, 1., out=x) + assert_(y is x) + + ############################################################################### if __name__ == "__main__": run_module_suite() |