diff options
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r-- | numpy/ma/core.py | 65 |
1 files changed, 50 insertions, 15 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 29b818c06..1ef063eda 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -371,24 +371,61 @@ def maximum_fill_value(obj): raise TypeError(errmsg) -def _recursive_set_default_fill_value(dtypedescr): +def _recursive_set_default_fill_value(dt): + """ + Creates the default fill value for structured dtypes + + Parameters + ---------- + dt: dtype + structured dtype + + Returns + ------- + val: tuple + tuple of values corresponding to the default structured fill value + + """ deflist = [] - for currentdescr in dtypedescr: - currenttype = currentdescr[1] - if isinstance(currenttype, list): + for name in dt.names: + currenttype = dt[name] + if currenttype.subdtype: + currenttype = currenttype.subdtype[0] + + if currenttype.names: deflist.append( tuple(_recursive_set_default_fill_value(currenttype))) else: - deflist.append(default_fill_value(np.dtype(currenttype))) + deflist.append(default_fill_value(currenttype)) return tuple(deflist) -def _recursive_set_fill_value(fillvalue, dtypedescr): - fillvalue = np.resize(fillvalue, len(dtypedescr)) +def _recursive_set_fill_value(fillvalue, dt): + """ + Creates a fill value for structured dtypes + + Parameters + ---------- + fillvalue: scalar or array_like + scalar or array representing the fill value. If it is of shorter + length than the number of fields in dt, it will be resized. + dt: dtype + structured dtype + + Returns + ------- + val: tuple + tuple of values corresponding to the structured fill value + + """ + fillvalue = np.resize(fillvalue, len(dt.names)) output_value = [] - for (fval, descr) in zip(fillvalue, dtypedescr): - cdtype = descr[1] - if isinstance(cdtype, list): + for (fval, name) in zip(fillvalue, dt.names): + cdtype = dt[name] + if cdtype.subdtype: + cdtype = cdtype.subdtype[0] + + if cdtype.names: output_value.append(tuple(_recursive_set_fill_value(fval, cdtype))) else: output_value.append(np.array(fval, dtype=cdtype).item()) @@ -411,9 +448,8 @@ def _check_fill_value(fill_value, ndtype): fields = ndtype.fields if fill_value is None: if fields: - descr = ndtype.descr - fill_value = np.array(_recursive_set_default_fill_value(descr), - dtype=ndtype,) + fill_value = np.array(_recursive_set_default_fill_value(ndtype), + dtype=ndtype) else: fill_value = default_fill_value(ndtype) elif fields: @@ -425,9 +461,8 @@ def _check_fill_value(fill_value, ndtype): err_msg = "Unable to transform %s to dtype %s" raise ValueError(err_msg % (fill_value, fdtype)) else: - descr = ndtype.descr fill_value = np.asarray(fill_value, dtype=object) - fill_value = np.array(_recursive_set_fill_value(fill_value, descr), + fill_value = np.array(_recursive_set_fill_value(fill_value, ndtype), dtype=ndtype) else: if isinstance(fill_value, basestring) and (ndtype.char not in 'OSVU'): |