diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-06-27 14:44:56 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-06-27 18:42:47 +0100 |
commit | 3600068426bfc86c1ccfa85a53394c6f6164a96f (patch) | |
tree | 2c657186a29480db97cf4f14b1b8a67798cf8066 /numpy/ma | |
parent | c80f6b640c3ca39594e3164637c57fec406275a3 (diff) | |
download | numpy-3600068426bfc86c1ccfa85a53394c6f6164a96f.tar.gz |
BUG: Overhaul *_fill_value functions
These now:
* Return void scalars instead of tuples for structured dtypes
* Actually accept non-trivial dtype arguments
* Work for nested structured and subarray dtypes
* Uniformly convert scalar types into dtypes
Also fixes #9315
Diffstat (limited to 'numpy/ma')
-rw-r--r-- | numpy/ma/core.py | 132 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 25 |
2 files changed, 75 insertions, 82 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index ed8652a45..c2b10d5f8 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -205,6 +205,31 @@ if 'float128' in ntypes.typeDict: min_filler.update([(np.float128, +np.inf)]) +def _recursive_fill_value(dtype, f): + """ + Recursively produce a fill value for `dtype`, calling f on scalar dtypes + """ + if dtype.names: + vals = tuple(_recursive_fill_value(dtype[name], f) for name in dtype.names) + return np.array(vals, dtype=dtype)[()] # decay to void scalar from 0d + elif dtype.subdtype: + subtype, shape = dtype.subdtype + subval = _recursive_fill_value(subtype, f) + return np.full(shape, subval) + else: + return f(dtype) + + +def _get_dtype_of(obj): + """ Convert the argument for *_fill_value into a dtype """ + if isinstance(obj, np.dtype): + return obj + elif hasattr(obj, 'dtype'): + return obj.dtype + else: + return np.asanyarray(obj).dtype + + def default_fill_value(obj): """ Return the default fill value for the argument object. @@ -223,6 +248,11 @@ def default_fill_value(obj): string 'N/A' ======== ======== + For structured types, a structured scalar is returned, with each field the + default fill value for its type. + + For subarray types, the fill value is an array of the same size containing + the default scalar fill value. Parameters ---------- @@ -245,62 +275,29 @@ def default_fill_value(obj): (1e+20+0j) """ - if hasattr(obj, 'dtype'): - return default_fill_value(obj.dtype) - elif isinstance(obj, np.dtype): - if obj.names: - return np.array(_recursive_set_default_fill_value(obj), - dtype=obj) - elif obj.subdtype: - return default_filler.get(obj.subdtype[0].kind, '?') - elif obj.kind in 'Mm': - return default_filler.get(obj.str[1:], '?') + def _scalar_fill_value(dtype): + if dtype.kind in 'Mm': + return default_filler.get(dtype.str[1:], '?') else: - return default_filler.get(obj.kind, '?') - elif isinstance(obj, float): - return default_filler['f'] - elif isinstance(obj, int) or isinstance(obj, long): - return default_filler['i'] - elif isinstance(obj, bytes): - return default_filler['S'] - elif isinstance(obj, unicode): - return default_filler['U'] - elif isinstance(obj, complex): - return default_filler['c'] - else: - return default_filler['O'] + return default_filler.get(dtype.kind, '?') - -def _recursive_extremum_fill_value(ndtype, extremum): - names = ndtype.names - if names: - deflist = [] - for name in names: - fval = _recursive_extremum_fill_value(ndtype[name], extremum) - deflist.append(fval) - return tuple(deflist) - elif ndtype.subdtype: - subtype, shape = ndtype.subdtype - subval = _recursive_extremum_fill_value(subtype, extremum) - return np.full(shape, subval) - else: - return extremum[ndtype] + dtype = _get_dtype_of(obj) + return _recursive_fill_value(dtype, _scalar_fill_value) def _extremum_fill_value(obj, extremum, extremum_name): - if hasattr(obj, 'dtype'): - return _recursive_extremum_fill_value(obj.dtype, extremum) - elif isinstance(obj, float): - return extremum[ntypes.typeDict['float_']] - elif isinstance(obj, int): - return extremum[ntypes.typeDict['int_']] - elif isinstance(obj, long): - return extremum[ntypes.typeDict['uint']] - elif isinstance(obj, np.dtype): - return extremum[obj] - else: - raise TypeError( - "Unsuitable type for calculating {}.".format(extremum_name)) + + def _scalar_fill_value(dtype): + try: + return extremum[dtype] + except KeyError: + raise TypeError( + "Unsuitable type {} for calculating {}." + .format(dtype, extremum_name) + ) + + dtype = _get_dtype_of(obj) + return _recursive_fill_value(dtype, _scalar_fill_value) def minimum_fill_value(obj): @@ -312,7 +309,7 @@ def minimum_fill_value(obj): Parameters ---------- - obj : ndarray or dtype + obj : ndarray, dtype or scalar An object that can be queried for it's numeric type. Returns @@ -363,7 +360,7 @@ def maximum_fill_value(obj): Parameters ---------- - obj : {ndarray, dtype} + obj : ndarray, dtype or scalar An object that can be queried for it's numeric type. Returns @@ -405,35 +402,6 @@ def maximum_fill_value(obj): return _extremum_fill_value(obj, max_filler, "maximum") -def _recursive_set_default_fill_value(dt): - """ - Create the default fill value for a structured dtype. - - Parameters - ---------- - dt: dtype - The structured dtype for which to create the fill value. - - Returns - ------- - val: tuple - A tuple of values corresponding to the default structured fill value. - - """ - deflist = [] - for name in dt.names: - currenttype = dt[name] - if currenttype.subdtype: - currenttype = currenttype.subdtype[0] - - if currenttype.names: - deflist.append( - tuple(_recursive_set_default_fill_value(currenttype))) - else: - deflist.append(default_fill_value(currenttype)) - return tuple(deflist) - - def _recursive_set_fill_value(fillvalue, dt): """ Create a fill value for a structured dtype. diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 3a401caed..816e149a9 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1788,6 +1788,26 @@ class TestFillingValues(TestCase): assert_equal(b['a']._data, a._data) assert_equal(b['a'].fill_value, a.fill_value) + def test_default_fill_value(self): + # check all calling conventions + f1 = default_fill_value(1.) + f2 = default_fill_value(np.array(1.)) + f3 = default_fill_value(np.array(1.).dtype) + assert_equal(f1, f2) + assert_equal(f1, f3) + + def test_default_fill_value_structured(self): + fields = array([(1, 1, 1)], + dtype=[('i', int), ('s', '|S8'), ('f', float)]) + + f1 = default_fill_value(fields) + f2 = default_fill_value(fields.dtype) + expected = np.array((default_fill_value(0), + default_fill_value('0'), + default_fill_value(0.)), dtype=fields.dtype) + assert_equal(f1, expected) + assert_equal(f2, expected) + def test_fillvalue(self): # Yet more fun with the fill_value data = masked_array([1, 2, 3], fill_value=-999) @@ -1863,17 +1883,20 @@ class TestFillingValues(TestCase): a = array([(1, (2, 3)), (4, (5, 6))], dtype=[('A', int), ('B', [('BA', int), ('BB', int)])]) test = a.fill_value + assert_equal(test.dtype, a.dtype) assert_equal(test['A'], default_fill_value(a['A'])) assert_equal(test['B']['BA'], default_fill_value(a['B']['BA'])) assert_equal(test['B']['BB'], default_fill_value(a['B']['BB'])) test = minimum_fill_value(a) + assert_equal(test.dtype, a.dtype) assert_equal(test[0], minimum_fill_value(a['A'])) assert_equal(test[1][0], minimum_fill_value(a['B']['BA'])) assert_equal(test[1][1], minimum_fill_value(a['B']['BB'])) assert_equal(test[1], minimum_fill_value(a['B'])) test = maximum_fill_value(a) + assert_equal(test.dtype, a.dtype) assert_equal(test[0], maximum_fill_value(a['A'])) assert_equal(test[1][0], maximum_fill_value(a['B']['BA'])) assert_equal(test[1][1], maximum_fill_value(a['B']['BB'])) @@ -1883,9 +1906,11 @@ class TestFillingValues(TestCase): a = array(([2, 3, 4],), dtype=[('value', np.int8, 3)]) test = minimum_fill_value(a) + assert_equal(test.dtype, a.dtype) assert_equal(test[0], np.full(3, minimum_fill_value(a['value']))) test = maximum_fill_value(a) + assert_equal(test.dtype, a.dtype) assert_equal(test[0], np.full(3, maximum_fill_value(a['value']))) def test_fillvalue_individual_fields(self): |