diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2014-04-04 11:47:40 -0400 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2015-04-22 19:06:52 -0400 |
commit | 7a84c5660539bb210746ba6b9b8e38d82d9fd330 (patch) | |
tree | a7d02e08b4a9ae144addeb051845782ff60e6fa8 | |
parent | 02b858326dac217607a83ed0bf4d7d51d5bfbfbe (diff) | |
download | numpy-7a84c5660539bb210746ba6b9b8e38d82d9fd330.tar.gz |
ENH: Let MaskedArray getter, setter respect baseclass overrides
-rw-r--r-- | doc/release/1.10.0-notes.rst | 6 | ||||
-rw-r--r-- | numpy/ma/core.py | 27 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 43 | ||||
-rw-r--r-- | numpy/ma/tests/test_subclassing.py | 93 |
4 files changed, 151 insertions, 18 deletions
diff --git a/doc/release/1.10.0-notes.rst b/doc/release/1.10.0-notes.rst index a7c0e2852..85bfd11dc 100644 --- a/doc/release/1.10.0-notes.rst +++ b/doc/release/1.10.0-notes.rst @@ -207,6 +207,12 @@ arguments for controlling backward compatibility of pickled Python objects. This enables Numpy on Python 3 to load npy files containing object arrays that were generated on Python 2. +MaskedArray support for more complicated base classes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Built-in assumptions that the baseclass behaved like a plain array are being +removed. In particalur, setting and getting elements and ranges will respect +baseclass overrides of ``__setitem__`` and ``__getitem__``. + Changes ======= diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 51e9f0f28..ee56e97e2 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -3043,8 +3043,7 @@ class MaskedArray(ndarray): # if getmask(indx) is not nomask: # msg = "Masked arrays must be filled before they can be used as indices!" # raise IndexError(msg) - _data = ndarray.view(self, ndarray) - dout = ndarray.__getitem__(_data, indx) + dout = self.data[indx] # We could directly use ndarray.__getitem__ on self... # But then we would have to modify __array_finalize__ to prevent the # mask of being reshaped if it hasn't been set up properly yet... @@ -3074,6 +3073,8 @@ class MaskedArray(ndarray): # Update the mask if needed if _mask is not nomask: dout._mask = _mask[indx] + # set shape to match that of data; this is needed for matrices + dout._mask.shape = dout.shape dout._sharedmask = True # Note: Don't try to check for m.any(), that'll take too long... return dout @@ -3091,16 +3092,16 @@ class MaskedArray(ndarray): # if getmask(indx) is not nomask: # msg = "Masked arrays must be filled before they can be used as indices!" # raise IndexError(msg) - _data = ndarray.view(self, ndarray.__getattribute__(self, '_baseclass')) - _mask = ndarray.__getattribute__(self, '_mask') + _data = self._data + _mask = self._mask if isinstance(indx, basestring): - ndarray.__setitem__(_data, indx, value) + _data[indx] = value if _mask is nomask: self._mask = _mask = make_mask_none(self.shape, self.dtype) _mask[indx] = getmask(value) return #........................................ - _dtype = ndarray.__getattribute__(_data, 'dtype') + _dtype = _data.dtype nbfields = len(_dtype.names or ()) #........................................ if value is masked: @@ -3124,21 +3125,21 @@ class MaskedArray(ndarray): mval = tuple([False] * nbfields) if _mask is nomask: # Set the data, then the mask - ndarray.__setitem__(_data, indx, dval) + _data[indx] = dval if mval is not nomask: _mask = self._mask = make_mask_none(self.shape, _dtype) - ndarray.__setitem__(_mask, indx, mval) + _mask[indx] = mval elif not self._hardmask: # Unshare the mask if necessary to avoid propagation if not self._isfield: self.unshare_mask() - _mask = ndarray.__getattribute__(self, '_mask') + _mask = self._mask # Set the data, then the mask - ndarray.__setitem__(_data, indx, dval) - ndarray.__setitem__(_mask, indx, mval) + _data[indx] = dval + _mask[indx] = mval elif hasattr(indx, 'dtype') and (indx.dtype == MaskType): indx = indx * umath.logical_not(_mask) - ndarray.__setitem__(_data, indx, dval) + _data[indx] = dval else: if nbfields: err_msg = "Flexible 'hard' masks are not yet supported..." @@ -3149,7 +3150,7 @@ class MaskedArray(ndarray): np.copyto(dindx, dval, where=~mindx) elif mindx is nomask: dindx = dval - ndarray.__setitem__(_data, indx, dindx) + _data[indx] = dindx _mask[indx] = mindx return diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 807fc0ba6..32280b7de 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -275,6 +275,49 @@ class TestMaskedArray(TestCase): assert_equal(s1, s2) assert_(x1[1:1].shape == (0,)) + def test_matrix_indexing(self): + # Tests conversions and indexing + x1 = np.matrix([[1, 2, 3], [4, 3, 2]]) + x2 = array(x1, mask=[[1, 0, 0], [0, 1, 0]]) + x3 = array(x1, mask=[[0, 1, 0], [1, 0, 0]]) + x4 = array(x1) + # test conversion to strings + junk, garbage = str(x2), repr(x2) + # assert_equal(np.sort(x1), sort(x2, endwith=False)) + # tests of indexing + assert_(type(x2[1, 0]) is type(x1[1, 0])) + assert_(x1[1, 0] == x2[1, 0]) + assert_(x2[1, 1] is masked) + assert_equal(x1[0, 2], x2[0, 2]) + assert_equal(x1[0, 1:], x2[0, 1:]) + assert_equal(x1[:, 2], x2[:, 2]) + assert_equal(x1[:], x2[:]) + assert_equal(x1[1:], x3[1:]) + x1[0, 2] = 9 + x2[0, 2] = 9 + assert_equal(x1, x2) + x1[0, 1:] = 99 + x2[0, 1:] = 99 + assert_equal(x1, x2) + x2[0, 1] = masked + assert_equal(x1, x2) + x2[0, 1:] = masked + assert_equal(x1, x2) + x2[0, :] = x1[0, :] + x2[0, 1] = masked + assert_(allequal(getmask(x2), np.array([[0, 1, 0], [0, 1, 0]]))) + x3[1, :] = masked_array([1, 2, 3], [1, 1, 0]) + assert_(allequal(getmask(x3)[1], array([1, 1, 0]))) + assert_(allequal(getmask(x3[1]), array([1, 1, 0]))) + x4[1, :] = masked_array([1, 2, 3], [1, 1, 0]) + assert_(allequal(getmask(x4[1]), array([1, 1, 0]))) + assert_(allequal(x4[1], array([1, 2, 3]))) + x1 = np.matrix(np.arange(5) * 1.0) + x2 = masked_values(x1, 3.0) + assert_equal(x1, x2) + assert_(allequal(array([0, 0, 0, 1, 0], MaskType), x2.mask)) + assert_equal(3.0, x2.fill_value) + def test_copy(self): # Tests of some subtle points of copying and sizing. n = [0, 0, 1, 0, 0] diff --git a/numpy/ma/tests/test_subclassing.py b/numpy/ma/tests/test_subclassing.py index ade5c59da..07fc8fdd6 100644 --- a/numpy/ma/tests/test_subclassing.py +++ b/numpy/ma/tests/test_subclassing.py @@ -84,20 +84,71 @@ mmatrix = MMatrix # also a subclass that overrides __str__, __repr__ and __setitem__, disallowing # setting to non-class values (and thus np.ma.core.masked_print_option) +class CSAIterator(object): + """ + Flat iterator object that uses its own setter/getter + (works around ndarray.flat not propagating subclass setters/getters + see https://github.com/numpy/numpy/issues/4564) + roughly following MaskedIterator + """ + def __init__(self, a): + self._original = a + self._dataiter = a.view(np.ndarray).flat + + def __iter__(self): + return self + + def __getitem__(self, indx): + out = self._dataiter.__getitem__(indx) + if not isinstance(out, np.ndarray): + out = out.__array__() + out = out.view(type(self._original)) + return out + + def __setitem__(self, index, value): + self._dataiter[index] = self._original._validate_input(value) + + def __next__(self): + return next(self._dataiter).__array__().view(type(self._original)) + + next = __next__ + + class ComplicatedSubArray(SubArray): + def __str__(self): - return 'myprefix {0} mypostfix'.format( - super(ComplicatedSubArray, self).__str__()) + return 'myprefix {0} mypostfix'.format(self.view(SubArray)) def __repr__(self): # Return a repr that does not start with 'name(' return '<{0} {1}>'.format(self.__class__.__name__, self) - def __setitem__(self, item, value): - # this ensures direct assignment to masked_print_option will fail + def _validate_input(self, value): if not isinstance(value, ComplicatedSubArray): raise ValueError("Can only set to MySubArray values") - super(ComplicatedSubArray, self).__setitem__(item, value) + return value + + def __setitem__(self, item, value): + # validation ensures direct assignment with ndarray or + # masked_print_option will fail + super(ComplicatedSubArray, self).__setitem__( + item, self._validate_input(value)) + + def __getitem__(self, item): + # ensure getter returns our own class also for scalars + value = super(ComplicatedSubArray, self).__getitem__(item) + if not isinstance(value, np.ndarray): # scalar + value = value.__array__().view(ComplicatedSubArray) + return value + + @property + def flat(self): + return CSAIterator(self) + + @flat.setter + def flat(self, value): + y = self.ravel() + y[:] = value class TestSubclassing(TestCase): @@ -205,6 +256,38 @@ class TestSubclassing(TestCase): assert_equal(mxsub.info, xsub.info) assert_equal(mxsub._mask, m) + def test_subclass_items(self): + """test that getter and setter go via baseclass""" + x = np.arange(5) + xcsub = ComplicatedSubArray(x) + mxcsub = masked_array(xcsub, mask=[True, False, True, False, False]) + # getter should return a ComplicatedSubArray, even for single item + # first check we wrote ComplicatedSubArray correctly + self.assertTrue(isinstance(xcsub[1], ComplicatedSubArray)) + self.assertTrue(isinstance(xcsub[1:4], ComplicatedSubArray)) + # now that it propagates inside the MaskedArray + self.assertTrue(isinstance(mxcsub[1], ComplicatedSubArray)) + self.assertTrue(mxcsub[0] is masked) + self.assertTrue(isinstance(mxcsub[1:4].data, ComplicatedSubArray)) + # also for flattened version (which goes via MaskedIterator) + self.assertTrue(isinstance(mxcsub.flat[1].data, ComplicatedSubArray)) + self.assertTrue(mxcsub[0] is masked) + self.assertTrue(isinstance(mxcsub.flat[1:4].base, ComplicatedSubArray)) + + # setter should only work with ComplicatedSubArray input + # first check we wrote ComplicatedSubArray correctly + assert_raises(ValueError, xcsub.__setitem__, 1, x[4]) + # now that it propagates inside the MaskedArray + assert_raises(ValueError, mxcsub.__setitem__, 1, x[4]) + assert_raises(ValueError, mxcsub.__setitem__, slice(1, 4), x[1:4]) + mxcsub[1] = xcsub[4] + mxcsub[1:4] = xcsub[1:4] + # also for flattened version (which goes via MaskedIterator) + assert_raises(ValueError, mxcsub.flat.__setitem__, 1, x[4]) + assert_raises(ValueError, mxcsub.flat.__setitem__, slice(1, 4), x[1:4]) + mxcsub.flat[1] = xcsub[4] + mxcsub.flat[1:4] = xcsub[1:4] + def test_subclass_repr(self): """test that repr uses the name of the subclass and 'array' for np.ndarray""" |