diff options
Diffstat (limited to 'numpy/ma')
-rw-r--r-- | numpy/ma/core.py | 153 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 72 |
2 files changed, 117 insertions, 108 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index a925eab15..ab4364706 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -205,6 +205,31 @@ if 'float128' in ntypes.typeDict: min_filler.update([(np.float128, +np.inf)]) +def _recursive_fill_value(dtype, f): + """ + Recursively produce a fill value for `dtype`, calling f on scalar dtypes + """ + if dtype.names: + vals = tuple(_recursive_fill_value(dtype[name], f) for name in dtype.names) + return np.array(vals, dtype=dtype)[()] # decay to void scalar from 0d + elif dtype.subdtype: + subtype, shape = dtype.subdtype + subval = _recursive_fill_value(subtype, f) + return np.full(shape, subval) + else: + return f(dtype) + + +def _get_dtype_of(obj): + """ Convert the argument for *_fill_value into a dtype """ + if isinstance(obj, np.dtype): + return obj + elif hasattr(obj, 'dtype'): + return obj.dtype + else: + return np.asanyarray(obj).dtype + + def default_fill_value(obj): """ Return the default fill value for the argument object. @@ -223,6 +248,11 @@ def default_fill_value(obj): string 'N/A' ======== ======== + For structured types, a structured scalar is returned, with each field the + default fill value for its type. + + For subarray types, the fill value is an array of the same size containing + the default scalar fill value. Parameters ---------- @@ -245,39 +275,29 @@ def default_fill_value(obj): (1e+20+0j) """ - if hasattr(obj, 'dtype'): - defval = _check_fill_value(None, obj.dtype) - elif isinstance(obj, np.dtype): - if obj.subdtype: - defval = default_filler.get(obj.subdtype[0].kind, '?') - elif obj.kind in 'Mm': - defval = default_filler.get(obj.str[1:], '?') + def _scalar_fill_value(dtype): + if dtype.kind in 'Mm': + return default_filler.get(dtype.str[1:], '?') else: - defval = default_filler.get(obj.kind, '?') - elif isinstance(obj, float): - defval = default_filler['f'] - elif isinstance(obj, int) or isinstance(obj, long): - defval = default_filler['i'] - elif isinstance(obj, bytes): - defval = default_filler['S'] - elif isinstance(obj, unicode): - defval = default_filler['U'] - elif isinstance(obj, complex): - defval = default_filler['c'] - else: - defval = default_filler['O'] - return defval + return default_filler.get(dtype.kind, '?') + dtype = _get_dtype_of(obj) + return _recursive_fill_value(dtype, _scalar_fill_value) -def _recursive_extremum_fill_value(ndtype, extremum): - names = ndtype.names - if names: - deflist = [] - for name in names: - fval = _recursive_extremum_fill_value(ndtype[name], extremum) - deflist.append(fval) - return tuple(deflist) - return extremum[ndtype] + +def _extremum_fill_value(obj, extremum, extremum_name): + + def _scalar_fill_value(dtype): + try: + return extremum[dtype] + except KeyError: + raise TypeError( + "Unsuitable type {} for calculating {}." + .format(dtype, extremum_name) + ) + + dtype = _get_dtype_of(obj) + return _recursive_fill_value(dtype, _scalar_fill_value) def minimum_fill_value(obj): @@ -289,7 +309,7 @@ def minimum_fill_value(obj): Parameters ---------- - obj : ndarray or dtype + obj : ndarray, dtype or scalar An object that can be queried for it's numeric type. Returns @@ -328,19 +348,7 @@ def minimum_fill_value(obj): inf """ - errmsg = "Unsuitable type for calculating minimum." - if hasattr(obj, 'dtype'): - return _recursive_extremum_fill_value(obj.dtype, min_filler) - elif isinstance(obj, float): - return min_filler[ntypes.typeDict['float_']] - elif isinstance(obj, int): - return min_filler[ntypes.typeDict['int_']] - elif isinstance(obj, long): - return min_filler[ntypes.typeDict['uint']] - elif isinstance(obj, np.dtype): - return min_filler[obj] - else: - raise TypeError(errmsg) + return _extremum_fill_value(obj, min_filler, "minimum") def maximum_fill_value(obj): @@ -352,7 +360,7 @@ def maximum_fill_value(obj): Parameters ---------- - obj : {ndarray, dtype} + obj : ndarray, dtype or scalar An object that can be queried for it's numeric type. Returns @@ -391,48 +399,7 @@ def maximum_fill_value(obj): -inf """ - errmsg = "Unsuitable type for calculating maximum." - if hasattr(obj, 'dtype'): - return _recursive_extremum_fill_value(obj.dtype, max_filler) - elif isinstance(obj, float): - return max_filler[ntypes.typeDict['float_']] - elif isinstance(obj, int): - return max_filler[ntypes.typeDict['int_']] - elif isinstance(obj, long): - return max_filler[ntypes.typeDict['uint']] - elif isinstance(obj, np.dtype): - return max_filler[obj] - else: - raise TypeError(errmsg) - - -def _recursive_set_default_fill_value(dt): - """ - Create the default fill value for a structured dtype. - - Parameters - ---------- - dt: dtype - The structured dtype for which to create the fill value. - - Returns - ------- - val: tuple - A tuple of values corresponding to the default structured fill value. - - """ - deflist = [] - for name in dt.names: - currenttype = dt[name] - if currenttype.subdtype: - currenttype = currenttype.subdtype[0] - - if currenttype.names: - deflist.append( - tuple(_recursive_set_default_fill_value(currenttype))) - else: - deflist.append(default_fill_value(currenttype)) - return tuple(deflist) + return _extremum_fill_value(obj, max_filler, "maximum") def _recursive_set_fill_value(fillvalue, dt): @@ -471,22 +438,16 @@ def _check_fill_value(fill_value, ndtype): """ Private function validating the given `fill_value` for the given dtype. - If fill_value is None, it is set to the default corresponding to the dtype - if this latter is standard (no fields). If the datatype is flexible (named - fields), fill_value is set to a tuple whose elements are the default fill - values corresponding to each field. + If fill_value is None, it is set to the default corresponding to the dtype. If fill_value is not None, its value is forced to the given dtype. + The result is always a 0d array. """ ndtype = np.dtype(ndtype) fields = ndtype.fields if fill_value is None: - if fields: - fill_value = np.array(_recursive_set_default_fill_value(ndtype), - dtype=ndtype) - else: - fill_value = default_fill_value(ndtype) + fill_value = default_fill_value(ndtype) elif fields: fdtype = [(_[0], _[1]) for _ in ndtype.descr] if isinstance(fill_value, (ndarray, np.void)): diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 93f6d4be0..707fcd1de 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1788,6 +1788,26 @@ class TestFillingValues(TestCase): assert_equal(b['a']._data, a._data) assert_equal(b['a'].fill_value, a.fill_value) + def test_default_fill_value(self): + # check all calling conventions + f1 = default_fill_value(1.) + f2 = default_fill_value(np.array(1.)) + f3 = default_fill_value(np.array(1.).dtype) + assert_equal(f1, f2) + assert_equal(f1, f3) + + def test_default_fill_value_structured(self): + fields = array([(1, 1, 1)], + dtype=[('i', int), ('s', '|S8'), ('f', float)]) + + f1 = default_fill_value(fields) + f2 = default_fill_value(fields.dtype) + expected = np.array((default_fill_value(0), + default_fill_value('0'), + default_fill_value(0.)), dtype=fields.dtype) + assert_equal(f1, expected) + assert_equal(f2, expected) + def test_fillvalue(self): # Yet more fun with the fill_value data = masked_array([1, 2, 3], fill_value=-999) @@ -1863,22 +1883,36 @@ class TestFillingValues(TestCase): a = array([(1, (2, 3)), (4, (5, 6))], dtype=[('A', int), ('B', [('BA', int), ('BB', int)])]) test = a.fill_value + assert_equal(test.dtype, a.dtype) assert_equal(test['A'], default_fill_value(a['A'])) assert_equal(test['B']['BA'], default_fill_value(a['B']['BA'])) assert_equal(test['B']['BB'], default_fill_value(a['B']['BB'])) test = minimum_fill_value(a) + assert_equal(test.dtype, a.dtype) assert_equal(test[0], minimum_fill_value(a['A'])) assert_equal(test[1][0], minimum_fill_value(a['B']['BA'])) assert_equal(test[1][1], minimum_fill_value(a['B']['BB'])) assert_equal(test[1], minimum_fill_value(a['B'])) test = maximum_fill_value(a) + assert_equal(test.dtype, a.dtype) assert_equal(test[0], maximum_fill_value(a['A'])) assert_equal(test[1][0], maximum_fill_value(a['B']['BA'])) assert_equal(test[1][1], maximum_fill_value(a['B']['BB'])) assert_equal(test[1], maximum_fill_value(a['B'])) + def test_extremum_fill_value_subdtype(self): + a = array(([2, 3, 4],), dtype=[('value', np.int8, 3)]) + + test = minimum_fill_value(a) + assert_equal(test.dtype, a.dtype) + assert_equal(test[0], np.full(3, minimum_fill_value(a['value']))) + + test = maximum_fill_value(a) + assert_equal(test.dtype, a.dtype) + assert_equal(test[0], np.full(3, maximum_fill_value(a['value']))) + def test_fillvalue_individual_fields(self): # Test setting fill_value on individual fields ndtype = [('a', int), ('b', int)] @@ -3097,27 +3131,41 @@ class TestMaskedArrayMethods(TestCase): assert_equal(am, an) def test_sort_flexible(self): - # Test sort on flexible dtype. + # Test sort on structured dtype. a = array( data=[(3, 3), (3, 2), (2, 2), (2, 1), (1, 0), (1, 1), (1, 2)], mask=[(0, 0), (0, 1), (0, 0), (0, 0), (1, 0), (0, 0), (0, 0)], dtype=[('A', int), ('B', int)]) - - test = sort(a) - b = array( + mask_last = array( data=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3), (3, 2), (1, 0)], mask=[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 1), (1, 0)], dtype=[('A', int), ('B', int)]) - assert_equal(test, b) - assert_equal(test.mask, b.mask) + mask_first = array( + data=[(1, 0), (1, 1), (1, 2), (2, 1), (2, 2), (3, 2), (3, 3)], + mask=[(1, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 1), (0, 0)], + dtype=[('A', int), ('B', int)]) + + test = sort(a) + assert_equal(test, mask_last) + assert_equal(test.mask, mask_last.mask) test = sort(a, endwith=False) - b = array( - data=[(1, 0), (1, 1), (1, 2), (2, 1), (2, 2), (3, 2), (3, 3), ], - mask=[(1, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 1), (0, 0), ], - dtype=[('A', int), ('B', int)]) - assert_equal(test, b) - assert_equal(test.mask, b.mask) + assert_equal(test, mask_first) + assert_equal(test.mask, mask_first.mask) + + # Test sort on dtype with subarray (gh-8069) + dt = np.dtype([('v', int, 2)]) + a = a.view(dt) + mask_last = mask_last.view(dt) + mask_first = mask_first.view(dt) + + test = sort(a) + assert_equal(test, mask_last) + assert_equal(test.mask, mask_last.mask) + + test = sort(a, endwith=False) + assert_equal(test, mask_first) + assert_equal(test.mask, mask_first.mask) def test_argsort(self): # Test argsort |