summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorahaldane <ealloc@gmail.com>2017-03-06 11:58:00 -0500
committerGitHub <noreply@github.com>2017-03-06 11:58:00 -0500
commit66f1b8a5411e8265ed87ff7bf97c67960af2624d (patch)
tree712f692657a34c970e5fae3669ff2267711193b0
parentdf9e134d1e40de4d3865597a239c5a1d68624f5d (diff)
parentb1a70570c2110734ed6a371df2224a34aa20b9ec (diff)
downloadnumpy-66f1b8a5411e8265ed87ff7bf97c67960af2624d.tar.gz
Merge pull request #8667 from eric-wieser/reuse-ma-dtype
BUG: Preserve identity of dtypes in make_mask_descr
-rw-r--r--numpy/ma/core.py60
-rw-r--r--numpy/ma/tests/test_core.py20
2 files changed, 60 insertions, 20 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index ea4a1d85f..5ef4e1369 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -1286,25 +1286,50 @@ mod = _DomainedBinaryOperation(umath.mod, _DomainSafeDivide(), 0, 1)
###############################################################################
-def _recursive_make_descr(datatype, newtype=bool_):
- "Private function allowing recursion in make_descr."
+def _replace_dtype_fields_recursive(dtype, primitive_dtype):
+ "Private function allowing recursion in _replace_dtype_fields."
+ _recurse = _replace_dtype_fields_recursive
+
# Do we have some name fields ?
- if datatype.names:
+ if dtype.names:
descr = []
- for name in datatype.names:
- field = datatype.fields[name]
+ for name in dtype.names:
+ field = dtype.fields[name]
if len(field) == 3:
# Prepend the title to the name
name = (field[-1], name)
- descr.append((name, _recursive_make_descr(field[0], newtype)))
- return descr
+ descr.append((name, _recurse(field[0], primitive_dtype)))
+ new_dtype = np.dtype(descr)
+
# Is this some kind of composite a la (np.float,2)
- elif datatype.subdtype:
- mdescr = list(datatype.subdtype)
- mdescr[0] = _recursive_make_descr(datatype.subdtype[0], newtype)
- return tuple(mdescr)
+ elif dtype.subdtype:
+ descr = list(dtype.subdtype)
+ descr[0] = _recurse(dtype.subdtype[0], primitive_dtype)
+ new_dtype = np.dtype(tuple(descr))
+
+ # this is a primitive type, so do a direct replacement
else:
- return newtype
+ new_dtype = primitive_dtype
+
+ # preserve identity of dtypes
+ if new_dtype == dtype:
+ new_dtype = dtype
+
+ return new_dtype
+
+
+def _replace_dtype_fields(dtype, primitive_dtype):
+ """
+ Construct a dtype description list from a given dtype.
+
+ Returns a new dtype object, with all fields and subtypes in the given type
+ recursively replaced with `primitive_dtype`.
+
+ Arguments are coerced to dtypes first.
+ """
+ dtype = np.dtype(dtype)
+ primitive_dtype = np.dtype(primitive_dtype)
+ return _replace_dtype_fields_recursive(dtype, primitive_dtype)
def make_mask_descr(ndtype):
@@ -1334,13 +1359,10 @@ def make_mask_descr(ndtype):
>>> ma.make_mask_descr(dtype)
dtype([('foo', '|b1'), ('bar', '|b1')])
>>> ma.make_mask_descr(np.float32)
- <type 'numpy.bool_'>
+ dtype('bool')
"""
- # Make sure we do have a dtype
- if not isinstance(ndtype, np.dtype):
- ndtype = np.dtype(ndtype)
- return np.dtype(_recursive_make_descr(ndtype, np.bool))
+ return _replace_dtype_fields(ndtype, MaskType)
def getmask(a):
@@ -3826,7 +3848,7 @@ class MaskedArray(ndarray):
res = data.astype("O")
res.view(ndarray)[mask] = f
else:
- rdtype = _recursive_make_descr(self.dtype, "O")
+ rdtype = _replace_dtype_fields(self.dtype, "O")
res = self._data.astype(rdtype)
_recursive_printoption(res, m, f)
else:
@@ -5972,7 +5994,7 @@ class mvoid(MaskedArray):
if m is nomask:
return self._data.__str__()
printopt = masked_print_option
- rdtype = _recursive_make_descr(self._data.dtype, "O")
+ rdtype = _replace_dtype_fields(self._data.dtype, "O")
# temporary hack to fix gh-7493. A more permanent fix
# is proposed in gh-6053, after which the next two
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index ca1ef16c4..93898c4d0 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -4000,32 +4000,50 @@ class TestMaskedArrayFunctions(TestCase):
self.assertTrue(c.flags['C'])
def test_make_mask_descr(self):
- # Test make_mask_descr
# Flexible
ntype = [('a', np.float), ('b', np.float)]
test = make_mask_descr(ntype)
assert_equal(test, [('a', np.bool), ('b', np.bool)])
+ assert_(test is make_mask_descr(test))
+
# Standard w/ shape
ntype = (np.float, 2)
test = make_mask_descr(ntype)
assert_equal(test, (np.bool, 2))
+ assert_(test is make_mask_descr(test))
+
# Standard standard
ntype = np.float
test = make_mask_descr(ntype)
assert_equal(test, np.dtype(np.bool))
+ assert_(test is make_mask_descr(test))
+
# Nested
ntype = [('a', np.float), ('b', [('ba', np.float), ('bb', np.float)])]
test = make_mask_descr(ntype)
control = np.dtype([('a', 'b1'), ('b', [('ba', 'b1'), ('bb', 'b1')])])
assert_equal(test, control)
+ assert_(test is make_mask_descr(test))
+
# Named+ shape
ntype = [('a', (np.float, 2))]
test = make_mask_descr(ntype)
assert_equal(test, np.dtype([('a', (np.bool, 2))]))
+ assert_(test is make_mask_descr(test))
+
# 2 names
ntype = [(('A', 'a'), float)]
test = make_mask_descr(ntype)
assert_equal(test, np.dtype([(('A', 'a'), bool)]))
+ assert_(test is make_mask_descr(test))
+
+ # nested boolean types should preserve identity
+ base_type = np.dtype([('a', int, 3)])
+ base_mtype = make_mask_descr(base_type)
+ sub_type = np.dtype([('a', int), ('b', base_mtype)])
+ test = make_mask_descr(sub_type)
+ assert_equal(test, np.dtype([('a', bool), ('b', [('a', bool, 3)])]))
+ assert_(test.fields['b'][0] is base_mtype)
def test_make_mask(self):
# Test make_mask