diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-03-03 12:49:14 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-03-03 12:49:14 +0000 |
commit | b1a70570c2110734ed6a371df2224a34aa20b9ec (patch) | |
tree | 195e29f09749a412d686e3ae36ca8c78b20342f0 /numpy/ma | |
parent | 02f33887b751d288ec2238143aa8048a74f731f2 (diff) | |
download | numpy-b1a70570c2110734ed6a371df2224a34aa20b9ec.tar.gz |
MAINT: Don't call internal recursion helpers elsewhere
Cleans up make_mask_descr, and adds a private _replace_dtype_fields, which
we need elsewhere to replace all dtypes with `object` instead of `bool`.
This new version also removes repeated calls to `np.dtype`, doing the conversion
only once.
Diffstat (limited to 'numpy/ma')
-rw-r--r-- | numpy/ma/core.py | 59 |
1 files changed, 38 insertions, 21 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index ad67c582a..5ef4e1369 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -1286,32 +1286,52 @@ mod = _DomainedBinaryOperation(umath.mod, _DomainSafeDivide(), 0, 1) ############################################################################### -def _recursive_make_descr(datatype, newtype=bool_): - "Private function allowing recursion in make_descr." +def _replace_dtype_fields_recursive(dtype, primitive_dtype): + "Private function allowing recursion in _replace_dtype_fields." + _recurse = _replace_dtype_fields_recursive + # Do we have some name fields ? - if datatype.names: + if dtype.names: descr = [] - for name in datatype.names: - field = datatype.fields[name] + for name in dtype.names: + field = dtype.fields[name] if len(field) == 3: # Prepend the title to the name name = (field[-1], name) - descr.append((name, _recursive_make_descr(field[0], newtype))) + descr.append((name, _recurse(field[0], primitive_dtype))) + new_dtype = np.dtype(descr) + # Is this some kind of composite a la (np.float,2) - elif datatype.subdtype: - descr = list(datatype.subdtype) - descr[0] = _recursive_make_descr(datatype.subdtype[0], newtype) - descr = tuple(descr) + elif dtype.subdtype: + descr = list(dtype.subdtype) + descr[0] = _recurse(dtype.subdtype[0], primitive_dtype) + new_dtype = np.dtype(tuple(descr)) + + # this is a primitive type, so do a direct replacement else: - descr = newtype + new_dtype = primitive_dtype - new_dtype = np.dtype(descr) - if new_dtype == datatype: - new_dtype = datatype + # preserve identity of dtypes + if new_dtype == dtype: + new_dtype = dtype return new_dtype +def _replace_dtype_fields(dtype, primitive_dtype): + """ + Construct a dtype description list from a given dtype. + + Returns a new dtype object, with all fields and subtypes in the given type + recursively replaced with `primitive_dtype`. + + Arguments are coerced to dtypes first. + """ + dtype = np.dtype(dtype) + primitive_dtype = np.dtype(primitive_dtype) + return _replace_dtype_fields_recursive(dtype, primitive_dtype) + + def make_mask_descr(ndtype): """ Construct a dtype description list from a given dtype. @@ -1339,13 +1359,10 @@ def make_mask_descr(ndtype): >>> ma.make_mask_descr(dtype) dtype([('foo', '|b1'), ('bar', '|b1')]) >>> ma.make_mask_descr(np.float32) - <type 'numpy.bool_'> + dtype('bool') """ - # Make sure we do have a dtype - if not isinstance(ndtype, np.dtype): - ndtype = np.dtype(ndtype) - return _recursive_make_descr(ndtype, np.bool) + return _replace_dtype_fields(ndtype, MaskType) def getmask(a): @@ -3831,7 +3848,7 @@ class MaskedArray(ndarray): res = data.astype("O") res.view(ndarray)[mask] = f else: - rdtype = _recursive_make_descr(self.dtype, "O") + rdtype = _replace_dtype_fields(self.dtype, "O") res = self._data.astype(rdtype) _recursive_printoption(res, m, f) else: @@ -5977,7 +5994,7 @@ class mvoid(MaskedArray): if m is nomask: return self._data.__str__() printopt = masked_print_option - rdtype = _recursive_make_descr(self._data.dtype, "O") + rdtype = _replace_dtype_fields(self._data.dtype, "O") # temporary hack to fix gh-7493. A more permanent fix # is proposed in gh-6053, after which the next two |