summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py76
-rw-r--r--numpy/ma/tests/test_core.py49
-rw-r--r--numpy/ma/tests/test_subclassing.py20
3 files changed, 126 insertions, 19 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 3ce44caf7..204f0cd06 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -3183,29 +3183,78 @@ class MaskedArray(ndarray):
Return the item described by i, as a masked array.
"""
- dout = self.data[indx]
# We could directly use ndarray.__getitem__ on self.
# But then we would have to modify __array_finalize__ to prevent the
# mask of being reshaped if it hasn't been set up properly yet
# So it's easier to stick to the current version
+ dout = self.data[indx]
_mask = self._mask
+
+ def _is_scalar(m):
+ return not isinstance(m, np.ndarray)
+
+ def _scalar_heuristic(arr, elem):
+ """
+ Return whether `elem` is a scalar result of indexing `arr`, or None
+ if undecidable without promoting nomask to a full mask
+ """
+ # obviously a scalar
+ if not isinstance(elem, np.ndarray):
+ return True
+
+ # object array scalar indexing can return anything
+ elif arr.dtype.type is np.object_:
+ if arr.dtype is not elem.dtype:
+ # elem is an array, but dtypes do not match, so must be
+ # an element
+ return True
+
+ # well-behaved subclass that only returns 0d arrays when
+ # expected - this is not a scalar
+ elif type(arr).__getitem__ == ndarray.__getitem__:
+ return False
+
+ return None
+
+ if _mask is not nomask:
+ # _mask cannot be a subclass, so it tells us whether we should
+ # expect a scalar. It also cannot be of dtype object.
+ mout = _mask[indx]
+ scalar_expected = _is_scalar(mout)
+
+ else:
+ # attempt to apply the heuristic to avoid constructing a full mask
+ mout = nomask
+ scalar_expected = _scalar_heuristic(self.data, dout)
+ if scalar_expected is None:
+ # heuristics have failed
+ # construct a full array, so we can be certain. This is costly.
+ # we could also fall back on ndarray.__getitem__(self.data, indx)
+ scalar_expected = _is_scalar(getmaskarray(self)[indx])
+
# Did we extract a single item?
- if not getattr(dout, 'ndim', False):
+ if scalar_expected:
# A record
if isinstance(dout, np.void):
- mask = _mask[indx]
# We should always re-cast to mvoid, otherwise users can
# change masks on rows that already have masked values, but not
# on rows that have no masked values, which is inconsistent.
- dout = mvoid(dout, mask=mask, hardmask=self._hardmask)
+ return mvoid(dout, mask=mout, hardmask=self._hardmask)
+
+ # special case introduced in gh-5962
+ elif self.dtype.type is np.object_ and isinstance(dout, np.ndarray):
+ # If masked, turn into a MaskedArray, with everything masked.
+ if mout:
+ return MaskedArray(dout, mask=True)
+ else:
+ return dout
+
# Just a scalar
- elif _mask is not nomask and _mask[indx]:
- return masked
- elif self.dtype.type is np.object_ and self.dtype is not dout.dtype:
- # self contains an object array of arrays (yes, that happens).
- # If masked, turn into a MaskedArray, with everything masked.
- if _mask is not nomask and _mask[indx]:
- return MaskedArray(dout, mask=True)
+ else:
+ if mout:
+ return masked
+ else:
+ return dout
else:
# Force dout to MA
dout = dout.view(type(self))
@@ -3238,10 +3287,9 @@ class MaskedArray(ndarray):
dout._fill_value = dout._fill_value.flat[0]
dout._isfield = True
# Update the mask if needed
- if _mask is not nomask:
- dout._mask = _mask[indx]
+ if mout is not nomask:
# set shape to match that of data; this is needed for matrices
- dout._mask.shape = dout.shape
+ dout._mask = reshape(mout, dout.shape)
dout._sharedmask = True
# Note: Don't try to check for m.any(), that'll take too long
return dout
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index d6a7492ca..b08b76826 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -4376,16 +4376,27 @@ class TestMaskedFields(TestCase):
[1, 0, 0, 0, 0, 0, 0, 0, 1, 0])),
dtype=[('a', bool), ('b', bool)])
# No mask
- self.assertTrue(isinstance(a[1], MaskedArray))
+ assert_equal(type(a[1]), mvoid)
+ assert_equal(type(a[1,...]), MaskedArray)
+
# One element masked
- self.assertTrue(isinstance(a[0], MaskedArray))
+ assert_equal(type(a[0]), mvoid)
assert_equal_records(a[0]._data, a._data[0])
assert_equal_records(a[0]._mask, a._mask[0])
+
+ assert_equal(type(a[0,...]), MaskedArray)
+ assert_equal_records(a[0,...]._data, a._data[0,...])
+ assert_equal_records(a[0,...]._mask, a._mask[0,...])
+
# All element masked
- self.assertTrue(isinstance(a[-2], MaskedArray))
+ assert_equal(type(a[-2]), mvoid)
assert_equal_records(a[-2]._data, a._data[-2])
assert_equal_records(a[-2]._mask, a._mask[-2])
+ assert_equal(type(a[-2, ...]), MaskedArray)
+ assert_equal_records(a[-2, ...]._data, a._data[-2, ...])
+ assert_equal_records(a[-2, ...]._mask, a._mask[-2, ...])
+
def test_setitem(self):
# Issue 4866: check that one can set individual items in [record][col]
# and [col][record] order
@@ -4429,6 +4440,38 @@ class TestMaskedFields(TestCase):
assert_equal(len(rec), len(self.data['ddtype']))
+class TestMaskedObjectArray(TestCase):
+
+ def test_getitem(self):
+ arr = np.ma.array([None, None])
+ for dt in [float, object]:
+ a0 = np.eye(2).astype(dt)
+ a1 = np.eye(3).astype(dt)
+ arr[0] = a0
+ arr[1] = a1
+
+ assert_(arr[0] is a0)
+ assert_(arr[1] is a1)
+ assert_(isinstance(arr[0,...], MaskedArray))
+ assert_(isinstance(arr[1,...], MaskedArray))
+ assert_(arr[0,...][()] is a0)
+ assert_(arr[1,...][()] is a1)
+
+ arr[0] = np.ma.masked
+
+ assert_(arr[1] is a1)
+ assert_(isinstance(arr[0,...], MaskedArray))
+ assert_(isinstance(arr[1,...], MaskedArray))
+ assert_equal(arr[0,...].mask, True)
+ assert_(arr[1,...][()] is a1)
+
+ # gh-5962 - object arrays of arrays do something special
+ assert_equal(arr[0].data, a0)
+ assert_equal(arr[0].mask, True)
+ assert_equal(arr[0,...][()].data, a0)
+ assert_equal(arr[0,...][()].mask, True)
+
+
class TestMaskedView(TestCase):
def setUp(self):
diff --git a/numpy/ma/tests/test_subclassing.py b/numpy/ma/tests/test_subclassing.py
index 8198c9d35..b2995fd57 100644
--- a/numpy/ma/tests/test_subclassing.py
+++ b/numpy/ma/tests/test_subclassing.py
@@ -9,7 +9,7 @@
from __future__ import division, absolute_import, print_function
import numpy as np
-from numpy.testing import TestCase, run_module_suite, assert_raises
+from numpy.testing import TestCase, run_module_suite, assert_raises, dec
from numpy.ma.testutils import assert_equal
from numpy.ma.core import (
array, arange, masked, MaskedArray, masked_array, log, add, hypot,
@@ -291,14 +291,19 @@ class TestSubclassing(TestCase):
# getter should return a ComplicatedSubArray, even for single item
# first check we wrote ComplicatedSubArray correctly
self.assertTrue(isinstance(xcsub[1], ComplicatedSubArray))
+ self.assertTrue(isinstance(xcsub[1,...], ComplicatedSubArray))
self.assertTrue(isinstance(xcsub[1:4], ComplicatedSubArray))
+
# now that it propagates inside the MaskedArray
self.assertTrue(isinstance(mxcsub[1], ComplicatedSubArray))
+ self.assertTrue(isinstance(mxcsub[1,...].data, ComplicatedSubArray))
self.assertTrue(mxcsub[0] is masked)
+ self.assertTrue(isinstance(mxcsub[0,...].data, ComplicatedSubArray))
self.assertTrue(isinstance(mxcsub[1:4].data, ComplicatedSubArray))
+
# also for flattened version (which goes via MaskedIterator)
self.assertTrue(isinstance(mxcsub.flat[1].data, ComplicatedSubArray))
- self.assertTrue(mxcsub[0] is masked)
+ self.assertTrue(mxcsub.flat[0] is masked)
self.assertTrue(isinstance(mxcsub.flat[1:4].base, ComplicatedSubArray))
# setter should only work with ComplicatedSubArray input
@@ -315,6 +320,17 @@ class TestSubclassing(TestCase):
mxcsub.flat[1] = xcsub[4]
mxcsub.flat[1:4] = xcsub[1:4]
+ def test_subclass_nomask_items(self):
+ x = np.arange(5)
+ xcsub = ComplicatedSubArray(x)
+ mxcsub_nomask = masked_array(xcsub)
+
+ self.assertTrue(isinstance(mxcsub_nomask[1,...].data, ComplicatedSubArray))
+ self.assertTrue(isinstance(mxcsub_nomask[0,...].data, ComplicatedSubArray))
+
+ self.assertTrue(isinstance(mxcsub_nomask[1], ComplicatedSubArray))
+ self.assertTrue(isinstance(mxcsub_nomask[0], ComplicatedSubArray))
+
def test_subclass_repr(self):
"""test that repr uses the name of the subclass
and 'array' for np.ndarray"""