summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py59
1 files changed, 38 insertions, 21 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index ad67c582a..5ef4e1369 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -1286,32 +1286,52 @@ mod = _DomainedBinaryOperation(umath.mod, _DomainSafeDivide(), 0, 1)
###############################################################################
-def _recursive_make_descr(datatype, newtype=bool_):
- "Private function allowing recursion in make_descr."
+def _replace_dtype_fields_recursive(dtype, primitive_dtype):
+ "Private function allowing recursion in _replace_dtype_fields."
+ _recurse = _replace_dtype_fields_recursive
+
# Do we have some name fields ?
- if datatype.names:
+ if dtype.names:
descr = []
- for name in datatype.names:
- field = datatype.fields[name]
+ for name in dtype.names:
+ field = dtype.fields[name]
if len(field) == 3:
# Prepend the title to the name
name = (field[-1], name)
- descr.append((name, _recursive_make_descr(field[0], newtype)))
+ descr.append((name, _recurse(field[0], primitive_dtype)))
+ new_dtype = np.dtype(descr)
+
# Is this some kind of composite a la (np.float,2)
- elif datatype.subdtype:
- descr = list(datatype.subdtype)
- descr[0] = _recursive_make_descr(datatype.subdtype[0], newtype)
- descr = tuple(descr)
+ elif dtype.subdtype:
+ descr = list(dtype.subdtype)
+ descr[0] = _recurse(dtype.subdtype[0], primitive_dtype)
+ new_dtype = np.dtype(tuple(descr))
+
+ # this is a primitive type, so do a direct replacement
else:
- descr = newtype
+ new_dtype = primitive_dtype
- new_dtype = np.dtype(descr)
- if new_dtype == datatype:
- new_dtype = datatype
+ # preserve identity of dtypes
+ if new_dtype == dtype:
+ new_dtype = dtype
return new_dtype
+def _replace_dtype_fields(dtype, primitive_dtype):
+ """
+ Construct a dtype description list from a given dtype.
+
+ Returns a new dtype object, with all fields and subtypes in the given type
+ recursively replaced with `primitive_dtype`.
+
+ Arguments are coerced to dtypes first.
+ """
+ dtype = np.dtype(dtype)
+ primitive_dtype = np.dtype(primitive_dtype)
+ return _replace_dtype_fields_recursive(dtype, primitive_dtype)
+
+
def make_mask_descr(ndtype):
"""
Construct a dtype description list from a given dtype.
@@ -1339,13 +1359,10 @@ def make_mask_descr(ndtype):
>>> ma.make_mask_descr(dtype)
dtype([('foo', '|b1'), ('bar', '|b1')])
>>> ma.make_mask_descr(np.float32)
- <type 'numpy.bool_'>
+ dtype('bool')
"""
- # Make sure we do have a dtype
- if not isinstance(ndtype, np.dtype):
- ndtype = np.dtype(ndtype)
- return _recursive_make_descr(ndtype, np.bool)
+ return _replace_dtype_fields(ndtype, MaskType)
def getmask(a):
@@ -3831,7 +3848,7 @@ class MaskedArray(ndarray):
res = data.astype("O")
res.view(ndarray)[mask] = f
else:
- rdtype = _recursive_make_descr(self.dtype, "O")
+ rdtype = _replace_dtype_fields(self.dtype, "O")
res = self._data.astype(rdtype)
_recursive_printoption(res, m, f)
else:
@@ -5977,7 +5994,7 @@ class mvoid(MaskedArray):
if m is nomask:
return self._data.__str__()
printopt = masked_print_option
- rdtype = _recursive_make_descr(self._data.dtype, "O")
+ rdtype = _replace_dtype_fields(self._data.dtype, "O")
# temporary hack to fix gh-7493. A more permanent fix
# is proposed in gh-6053, after which the next two