diff options
author | ahaldane <ealloc@gmail.com> | 2017-03-06 11:58:00 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-06 11:58:00 -0500 |
commit | 66f1b8a5411e8265ed87ff7bf97c67960af2624d (patch) | |
tree | 712f692657a34c970e5fae3669ff2267711193b0 | |
parent | df9e134d1e40de4d3865597a239c5a1d68624f5d (diff) | |
parent | b1a70570c2110734ed6a371df2224a34aa20b9ec (diff) | |
download | numpy-66f1b8a5411e8265ed87ff7bf97c67960af2624d.tar.gz |
Merge pull request #8667 from eric-wieser/reuse-ma-dtype
BUG: Preserve identity of dtypes in make_mask_descr
-rw-r--r-- | numpy/ma/core.py | 60 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 20 |
2 files changed, 60 insertions, 20 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index ea4a1d85f..5ef4e1369 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -1286,25 +1286,50 @@ mod = _DomainedBinaryOperation(umath.mod, _DomainSafeDivide(), 0, 1) ############################################################################### -def _recursive_make_descr(datatype, newtype=bool_): - "Private function allowing recursion in make_descr." +def _replace_dtype_fields_recursive(dtype, primitive_dtype): + "Private function allowing recursion in _replace_dtype_fields." + _recurse = _replace_dtype_fields_recursive + # Do we have some name fields ? - if datatype.names: + if dtype.names: descr = [] - for name in datatype.names: - field = datatype.fields[name] + for name in dtype.names: + field = dtype.fields[name] if len(field) == 3: # Prepend the title to the name name = (field[-1], name) - descr.append((name, _recursive_make_descr(field[0], newtype))) - return descr + descr.append((name, _recurse(field[0], primitive_dtype))) + new_dtype = np.dtype(descr) + # Is this some kind of composite a la (np.float,2) - elif datatype.subdtype: - mdescr = list(datatype.subdtype) - mdescr[0] = _recursive_make_descr(datatype.subdtype[0], newtype) - return tuple(mdescr) + elif dtype.subdtype: + descr = list(dtype.subdtype) + descr[0] = _recurse(dtype.subdtype[0], primitive_dtype) + new_dtype = np.dtype(tuple(descr)) + + # this is a primitive type, so do a direct replacement else: - return newtype + new_dtype = primitive_dtype + + # preserve identity of dtypes + if new_dtype == dtype: + new_dtype = dtype + + return new_dtype + + +def _replace_dtype_fields(dtype, primitive_dtype): + """ + Construct a dtype description list from a given dtype. + + Returns a new dtype object, with all fields and subtypes in the given type + recursively replaced with `primitive_dtype`. + + Arguments are coerced to dtypes first. + """ + dtype = np.dtype(dtype) + primitive_dtype = np.dtype(primitive_dtype) + return _replace_dtype_fields_recursive(dtype, primitive_dtype) def make_mask_descr(ndtype): @@ -1334,13 +1359,10 @@ def make_mask_descr(ndtype): >>> ma.make_mask_descr(dtype) dtype([('foo', '|b1'), ('bar', '|b1')]) >>> ma.make_mask_descr(np.float32) - <type 'numpy.bool_'> + dtype('bool') """ - # Make sure we do have a dtype - if not isinstance(ndtype, np.dtype): - ndtype = np.dtype(ndtype) - return np.dtype(_recursive_make_descr(ndtype, np.bool)) + return _replace_dtype_fields(ndtype, MaskType) def getmask(a): @@ -3826,7 +3848,7 @@ class MaskedArray(ndarray): res = data.astype("O") res.view(ndarray)[mask] = f else: - rdtype = _recursive_make_descr(self.dtype, "O") + rdtype = _replace_dtype_fields(self.dtype, "O") res = self._data.astype(rdtype) _recursive_printoption(res, m, f) else: @@ -5972,7 +5994,7 @@ class mvoid(MaskedArray): if m is nomask: return self._data.__str__() printopt = masked_print_option - rdtype = _recursive_make_descr(self._data.dtype, "O") + rdtype = _replace_dtype_fields(self._data.dtype, "O") # temporary hack to fix gh-7493. A more permanent fix # is proposed in gh-6053, after which the next two diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index ca1ef16c4..93898c4d0 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -4000,32 +4000,50 @@ class TestMaskedArrayFunctions(TestCase): self.assertTrue(c.flags['C']) def test_make_mask_descr(self): - # Test make_mask_descr # Flexible ntype = [('a', np.float), ('b', np.float)] test = make_mask_descr(ntype) assert_equal(test, [('a', np.bool), ('b', np.bool)]) + assert_(test is make_mask_descr(test)) + # Standard w/ shape ntype = (np.float, 2) test = make_mask_descr(ntype) assert_equal(test, (np.bool, 2)) + assert_(test is make_mask_descr(test)) + # Standard standard ntype = np.float test = make_mask_descr(ntype) assert_equal(test, np.dtype(np.bool)) + assert_(test is make_mask_descr(test)) + # Nested ntype = [('a', np.float), ('b', [('ba', np.float), ('bb', np.float)])] test = make_mask_descr(ntype) control = np.dtype([('a', 'b1'), ('b', [('ba', 'b1'), ('bb', 'b1')])]) assert_equal(test, control) + assert_(test is make_mask_descr(test)) + # Named+ shape ntype = [('a', (np.float, 2))] test = make_mask_descr(ntype) assert_equal(test, np.dtype([('a', (np.bool, 2))])) + assert_(test is make_mask_descr(test)) + # 2 names ntype = [(('A', 'a'), float)] test = make_mask_descr(ntype) assert_equal(test, np.dtype([(('A', 'a'), bool)])) + assert_(test is make_mask_descr(test)) + + # nested boolean types should preserve identity + base_type = np.dtype([('a', int, 3)]) + base_mtype = make_mask_descr(base_type) + sub_type = np.dtype([('a', int), ('b', base_mtype)]) + test = make_mask_descr(sub_type) + assert_equal(test, np.dtype([('a', bool), ('b', [('a', bool, 3)])])) + assert_(test.fields['b'][0] is base_mtype) def test_make_mask(self): # Test make_mask |