summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-06-03 18:48:38 -0600
committerCharles Harris <charlesr.harris@gmail.com>2015-06-03 18:48:38 -0600
commit36b940497f194b1b90e9ff48c58a01a6e74c09de (patch)
treec498667a93d5c18635bef33dd6cc0cb944c4c230
parenta9c810dd1d8fc1e3c6d0f0ca6310f41795545ec9 (diff)
parent7a84c5660539bb210746ba6b9b8e38d82d9fd330 (diff)
downloadnumpy-36b940497f194b1b90e9ff48c58a01a6e74c09de.tar.gz
Merge pull request #4586 from mhvk/ma/subclass-item-setting-getting
ENH: Let MaskedArray getter, setter respect baseclass overrides
-rw-r--r--doc/release/1.10.0-notes.rst6
-rw-r--r--numpy/ma/core.py27
-rw-r--r--numpy/ma/tests/test_core.py43
-rw-r--r--numpy/ma/tests/test_subclassing.py93
4 files changed, 151 insertions, 18 deletions
diff --git a/doc/release/1.10.0-notes.rst b/doc/release/1.10.0-notes.rst
index c4ff2e4b4..f2ab6e99a 100644
--- a/doc/release/1.10.0-notes.rst
+++ b/doc/release/1.10.0-notes.rst
@@ -235,6 +235,12 @@ arguments for controlling backward compatibility of pickled Python
objects. This enables Numpy on Python 3 to load npy files containing
object arrays that were generated on Python 2.
+MaskedArray support for more complicated base classes
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Built-in assumptions that the baseclass behaved like a plain array are being
+removed. In particalur, setting and getting elements and ranges will respect
+baseclass overrides of ``__setitem__`` and ``__getitem__``.
+
Changes
=======
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 45fb8c98b..5df928a6d 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -3044,8 +3044,7 @@ class MaskedArray(ndarray):
# if getmask(indx) is not nomask:
# msg = "Masked arrays must be filled before they can be used as indices!"
# raise IndexError(msg)
- _data = ndarray.view(self, ndarray)
- dout = ndarray.__getitem__(_data, indx)
+ dout = self.data[indx]
# We could directly use ndarray.__getitem__ on self...
# But then we would have to modify __array_finalize__ to prevent the
# mask of being reshaped if it hasn't been set up properly yet...
@@ -3075,6 +3074,8 @@ class MaskedArray(ndarray):
# Update the mask if needed
if _mask is not nomask:
dout._mask = _mask[indx]
+ # set shape to match that of data; this is needed for matrices
+ dout._mask.shape = dout.shape
dout._sharedmask = True
# Note: Don't try to check for m.any(), that'll take too long...
return dout
@@ -3092,16 +3093,16 @@ class MaskedArray(ndarray):
# if getmask(indx) is not nomask:
# msg = "Masked arrays must be filled before they can be used as indices!"
# raise IndexError(msg)
- _data = ndarray.view(self, ndarray.__getattribute__(self, '_baseclass'))
- _mask = ndarray.__getattribute__(self, '_mask')
+ _data = self._data
+ _mask = self._mask
if isinstance(indx, basestring):
- ndarray.__setitem__(_data, indx, value)
+ _data[indx] = value
if _mask is nomask:
self._mask = _mask = make_mask_none(self.shape, self.dtype)
_mask[indx] = getmask(value)
return
#........................................
- _dtype = ndarray.__getattribute__(_data, 'dtype')
+ _dtype = _data.dtype
nbfields = len(_dtype.names or ())
#........................................
if value is masked:
@@ -3125,21 +3126,21 @@ class MaskedArray(ndarray):
mval = tuple([False] * nbfields)
if _mask is nomask:
# Set the data, then the mask
- ndarray.__setitem__(_data, indx, dval)
+ _data[indx] = dval
if mval is not nomask:
_mask = self._mask = make_mask_none(self.shape, _dtype)
- ndarray.__setitem__(_mask, indx, mval)
+ _mask[indx] = mval
elif not self._hardmask:
# Unshare the mask if necessary to avoid propagation
if not self._isfield:
self.unshare_mask()
- _mask = ndarray.__getattribute__(self, '_mask')
+ _mask = self._mask
# Set the data, then the mask
- ndarray.__setitem__(_data, indx, dval)
- ndarray.__setitem__(_mask, indx, mval)
+ _data[indx] = dval
+ _mask[indx] = mval
elif hasattr(indx, 'dtype') and (indx.dtype == MaskType):
indx = indx * umath.logical_not(_mask)
- ndarray.__setitem__(_data, indx, dval)
+ _data[indx] = dval
else:
if nbfields:
err_msg = "Flexible 'hard' masks are not yet supported..."
@@ -3150,7 +3151,7 @@ class MaskedArray(ndarray):
np.copyto(dindx, dval, where=~mindx)
elif mindx is nomask:
dindx = dval
- ndarray.__setitem__(_data, indx, dindx)
+ _data[indx] = dindx
_mask[indx] = mindx
return
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index f8a28164e..6df235e9f 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -275,6 +275,49 @@ class TestMaskedArray(TestCase):
assert_equal(s1, s2)
assert_(x1[1:1].shape == (0,))
+ def test_matrix_indexing(self):
+ # Tests conversions and indexing
+ x1 = np.matrix([[1, 2, 3], [4, 3, 2]])
+ x2 = array(x1, mask=[[1, 0, 0], [0, 1, 0]])
+ x3 = array(x1, mask=[[0, 1, 0], [1, 0, 0]])
+ x4 = array(x1)
+ # test conversion to strings
+ junk, garbage = str(x2), repr(x2)
+ # assert_equal(np.sort(x1), sort(x2, endwith=False))
+ # tests of indexing
+ assert_(type(x2[1, 0]) is type(x1[1, 0]))
+ assert_(x1[1, 0] == x2[1, 0])
+ assert_(x2[1, 1] is masked)
+ assert_equal(x1[0, 2], x2[0, 2])
+ assert_equal(x1[0, 1:], x2[0, 1:])
+ assert_equal(x1[:, 2], x2[:, 2])
+ assert_equal(x1[:], x2[:])
+ assert_equal(x1[1:], x3[1:])
+ x1[0, 2] = 9
+ x2[0, 2] = 9
+ assert_equal(x1, x2)
+ x1[0, 1:] = 99
+ x2[0, 1:] = 99
+ assert_equal(x1, x2)
+ x2[0, 1] = masked
+ assert_equal(x1, x2)
+ x2[0, 1:] = masked
+ assert_equal(x1, x2)
+ x2[0, :] = x1[0, :]
+ x2[0, 1] = masked
+ assert_(allequal(getmask(x2), np.array([[0, 1, 0], [0, 1, 0]])))
+ x3[1, :] = masked_array([1, 2, 3], [1, 1, 0])
+ assert_(allequal(getmask(x3)[1], array([1, 1, 0])))
+ assert_(allequal(getmask(x3[1]), array([1, 1, 0])))
+ x4[1, :] = masked_array([1, 2, 3], [1, 1, 0])
+ assert_(allequal(getmask(x4[1]), array([1, 1, 0])))
+ assert_(allequal(x4[1], array([1, 2, 3])))
+ x1 = np.matrix(np.arange(5) * 1.0)
+ x2 = masked_values(x1, 3.0)
+ assert_equal(x1, x2)
+ assert_(allequal(array([0, 0, 0, 1, 0], MaskType), x2.mask))
+ assert_equal(3.0, x2.fill_value)
+
def test_copy(self):
# Tests of some subtle points of copying and sizing.
n = [0, 0, 1, 0, 0]
diff --git a/numpy/ma/tests/test_subclassing.py b/numpy/ma/tests/test_subclassing.py
index ade5c59da..07fc8fdd6 100644
--- a/numpy/ma/tests/test_subclassing.py
+++ b/numpy/ma/tests/test_subclassing.py
@@ -84,20 +84,71 @@ mmatrix = MMatrix
# also a subclass that overrides __str__, __repr__ and __setitem__, disallowing
# setting to non-class values (and thus np.ma.core.masked_print_option)
+class CSAIterator(object):
+ """
+ Flat iterator object that uses its own setter/getter
+ (works around ndarray.flat not propagating subclass setters/getters
+ see https://github.com/numpy/numpy/issues/4564)
+ roughly following MaskedIterator
+ """
+ def __init__(self, a):
+ self._original = a
+ self._dataiter = a.view(np.ndarray).flat
+
+ def __iter__(self):
+ return self
+
+ def __getitem__(self, indx):
+ out = self._dataiter.__getitem__(indx)
+ if not isinstance(out, np.ndarray):
+ out = out.__array__()
+ out = out.view(type(self._original))
+ return out
+
+ def __setitem__(self, index, value):
+ self._dataiter[index] = self._original._validate_input(value)
+
+ def __next__(self):
+ return next(self._dataiter).__array__().view(type(self._original))
+
+ next = __next__
+
+
class ComplicatedSubArray(SubArray):
+
def __str__(self):
- return 'myprefix {0} mypostfix'.format(
- super(ComplicatedSubArray, self).__str__())
+ return 'myprefix {0} mypostfix'.format(self.view(SubArray))
def __repr__(self):
# Return a repr that does not start with 'name('
return '<{0} {1}>'.format(self.__class__.__name__, self)
- def __setitem__(self, item, value):
- # this ensures direct assignment to masked_print_option will fail
+ def _validate_input(self, value):
if not isinstance(value, ComplicatedSubArray):
raise ValueError("Can only set to MySubArray values")
- super(ComplicatedSubArray, self).__setitem__(item, value)
+ return value
+
+ def __setitem__(self, item, value):
+ # validation ensures direct assignment with ndarray or
+ # masked_print_option will fail
+ super(ComplicatedSubArray, self).__setitem__(
+ item, self._validate_input(value))
+
+ def __getitem__(self, item):
+ # ensure getter returns our own class also for scalars
+ value = super(ComplicatedSubArray, self).__getitem__(item)
+ if not isinstance(value, np.ndarray): # scalar
+ value = value.__array__().view(ComplicatedSubArray)
+ return value
+
+ @property
+ def flat(self):
+ return CSAIterator(self)
+
+ @flat.setter
+ def flat(self, value):
+ y = self.ravel()
+ y[:] = value
class TestSubclassing(TestCase):
@@ -205,6 +256,38 @@ class TestSubclassing(TestCase):
assert_equal(mxsub.info, xsub.info)
assert_equal(mxsub._mask, m)
+ def test_subclass_items(self):
+ """test that getter and setter go via baseclass"""
+ x = np.arange(5)
+ xcsub = ComplicatedSubArray(x)
+ mxcsub = masked_array(xcsub, mask=[True, False, True, False, False])
+ # getter should return a ComplicatedSubArray, even for single item
+ # first check we wrote ComplicatedSubArray correctly
+ self.assertTrue(isinstance(xcsub[1], ComplicatedSubArray))
+ self.assertTrue(isinstance(xcsub[1:4], ComplicatedSubArray))
+ # now that it propagates inside the MaskedArray
+ self.assertTrue(isinstance(mxcsub[1], ComplicatedSubArray))
+ self.assertTrue(mxcsub[0] is masked)
+ self.assertTrue(isinstance(mxcsub[1:4].data, ComplicatedSubArray))
+ # also for flattened version (which goes via MaskedIterator)
+ self.assertTrue(isinstance(mxcsub.flat[1].data, ComplicatedSubArray))
+ self.assertTrue(mxcsub[0] is masked)
+ self.assertTrue(isinstance(mxcsub.flat[1:4].base, ComplicatedSubArray))
+
+ # setter should only work with ComplicatedSubArray input
+ # first check we wrote ComplicatedSubArray correctly
+ assert_raises(ValueError, xcsub.__setitem__, 1, x[4])
+ # now that it propagates inside the MaskedArray
+ assert_raises(ValueError, mxcsub.__setitem__, 1, x[4])
+ assert_raises(ValueError, mxcsub.__setitem__, slice(1, 4), x[1:4])
+ mxcsub[1] = xcsub[4]
+ mxcsub[1:4] = xcsub[1:4]
+ # also for flattened version (which goes via MaskedIterator)
+ assert_raises(ValueError, mxcsub.flat.__setitem__, 1, x[4])
+ assert_raises(ValueError, mxcsub.flat.__setitem__, slice(1, 4), x[1:4])
+ mxcsub.flat[1] = xcsub[4]
+ mxcsub.flat[1:4] = xcsub[1:4]
+
def test_subclass_repr(self):
"""test that repr uses the name of the subclass
and 'array' for np.ndarray"""