summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py153
-rw-r--r--numpy/ma/tests/test_core.py72
2 files changed, 117 insertions, 108 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index a925eab15..ab4364706 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -205,6 +205,31 @@ if 'float128' in ntypes.typeDict:
min_filler.update([(np.float128, +np.inf)])
+def _recursive_fill_value(dtype, f):
+ """
+ Recursively produce a fill value for `dtype`, calling f on scalar dtypes
+ """
+ if dtype.names:
+ vals = tuple(_recursive_fill_value(dtype[name], f) for name in dtype.names)
+ return np.array(vals, dtype=dtype)[()] # decay to void scalar from 0d
+ elif dtype.subdtype:
+ subtype, shape = dtype.subdtype
+ subval = _recursive_fill_value(subtype, f)
+ return np.full(shape, subval)
+ else:
+ return f(dtype)
+
+
+def _get_dtype_of(obj):
+ """ Convert the argument for *_fill_value into a dtype """
+ if isinstance(obj, np.dtype):
+ return obj
+ elif hasattr(obj, 'dtype'):
+ return obj.dtype
+ else:
+ return np.asanyarray(obj).dtype
+
+
def default_fill_value(obj):
"""
Return the default fill value for the argument object.
@@ -223,6 +248,11 @@ def default_fill_value(obj):
string 'N/A'
======== ========
+ For structured types, a structured scalar is returned, with each field the
+ default fill value for its type.
+
+ For subarray types, the fill value is an array of the same size containing
+ the default scalar fill value.
Parameters
----------
@@ -245,39 +275,29 @@ def default_fill_value(obj):
(1e+20+0j)
"""
- if hasattr(obj, 'dtype'):
- defval = _check_fill_value(None, obj.dtype)
- elif isinstance(obj, np.dtype):
- if obj.subdtype:
- defval = default_filler.get(obj.subdtype[0].kind, '?')
- elif obj.kind in 'Mm':
- defval = default_filler.get(obj.str[1:], '?')
+ def _scalar_fill_value(dtype):
+ if dtype.kind in 'Mm':
+ return default_filler.get(dtype.str[1:], '?')
else:
- defval = default_filler.get(obj.kind, '?')
- elif isinstance(obj, float):
- defval = default_filler['f']
- elif isinstance(obj, int) or isinstance(obj, long):
- defval = default_filler['i']
- elif isinstance(obj, bytes):
- defval = default_filler['S']
- elif isinstance(obj, unicode):
- defval = default_filler['U']
- elif isinstance(obj, complex):
- defval = default_filler['c']
- else:
- defval = default_filler['O']
- return defval
+ return default_filler.get(dtype.kind, '?')
+ dtype = _get_dtype_of(obj)
+ return _recursive_fill_value(dtype, _scalar_fill_value)
-def _recursive_extremum_fill_value(ndtype, extremum):
- names = ndtype.names
- if names:
- deflist = []
- for name in names:
- fval = _recursive_extremum_fill_value(ndtype[name], extremum)
- deflist.append(fval)
- return tuple(deflist)
- return extremum[ndtype]
+
+def _extremum_fill_value(obj, extremum, extremum_name):
+
+ def _scalar_fill_value(dtype):
+ try:
+ return extremum[dtype]
+ except KeyError:
+ raise TypeError(
+ "Unsuitable type {} for calculating {}."
+ .format(dtype, extremum_name)
+ )
+
+ dtype = _get_dtype_of(obj)
+ return _recursive_fill_value(dtype, _scalar_fill_value)
def minimum_fill_value(obj):
@@ -289,7 +309,7 @@ def minimum_fill_value(obj):
Parameters
----------
- obj : ndarray or dtype
+ obj : ndarray, dtype or scalar
An object that can be queried for it's numeric type.
Returns
@@ -328,19 +348,7 @@ def minimum_fill_value(obj):
inf
"""
- errmsg = "Unsuitable type for calculating minimum."
- if hasattr(obj, 'dtype'):
- return _recursive_extremum_fill_value(obj.dtype, min_filler)
- elif isinstance(obj, float):
- return min_filler[ntypes.typeDict['float_']]
- elif isinstance(obj, int):
- return min_filler[ntypes.typeDict['int_']]
- elif isinstance(obj, long):
- return min_filler[ntypes.typeDict['uint']]
- elif isinstance(obj, np.dtype):
- return min_filler[obj]
- else:
- raise TypeError(errmsg)
+ return _extremum_fill_value(obj, min_filler, "minimum")
def maximum_fill_value(obj):
@@ -352,7 +360,7 @@ def maximum_fill_value(obj):
Parameters
----------
- obj : {ndarray, dtype}
+ obj : ndarray, dtype or scalar
An object that can be queried for it's numeric type.
Returns
@@ -391,48 +399,7 @@ def maximum_fill_value(obj):
-inf
"""
- errmsg = "Unsuitable type for calculating maximum."
- if hasattr(obj, 'dtype'):
- return _recursive_extremum_fill_value(obj.dtype, max_filler)
- elif isinstance(obj, float):
- return max_filler[ntypes.typeDict['float_']]
- elif isinstance(obj, int):
- return max_filler[ntypes.typeDict['int_']]
- elif isinstance(obj, long):
- return max_filler[ntypes.typeDict['uint']]
- elif isinstance(obj, np.dtype):
- return max_filler[obj]
- else:
- raise TypeError(errmsg)
-
-
-def _recursive_set_default_fill_value(dt):
- """
- Create the default fill value for a structured dtype.
-
- Parameters
- ----------
- dt: dtype
- The structured dtype for which to create the fill value.
-
- Returns
- -------
- val: tuple
- A tuple of values corresponding to the default structured fill value.
-
- """
- deflist = []
- for name in dt.names:
- currenttype = dt[name]
- if currenttype.subdtype:
- currenttype = currenttype.subdtype[0]
-
- if currenttype.names:
- deflist.append(
- tuple(_recursive_set_default_fill_value(currenttype)))
- else:
- deflist.append(default_fill_value(currenttype))
- return tuple(deflist)
+ return _extremum_fill_value(obj, max_filler, "maximum")
def _recursive_set_fill_value(fillvalue, dt):
@@ -471,22 +438,16 @@ def _check_fill_value(fill_value, ndtype):
"""
Private function validating the given `fill_value` for the given dtype.
- If fill_value is None, it is set to the default corresponding to the dtype
- if this latter is standard (no fields). If the datatype is flexible (named
- fields), fill_value is set to a tuple whose elements are the default fill
- values corresponding to each field.
+ If fill_value is None, it is set to the default corresponding to the dtype.
If fill_value is not None, its value is forced to the given dtype.
+ The result is always a 0d array.
"""
ndtype = np.dtype(ndtype)
fields = ndtype.fields
if fill_value is None:
- if fields:
- fill_value = np.array(_recursive_set_default_fill_value(ndtype),
- dtype=ndtype)
- else:
- fill_value = default_fill_value(ndtype)
+ fill_value = default_fill_value(ndtype)
elif fields:
fdtype = [(_[0], _[1]) for _ in ndtype.descr]
if isinstance(fill_value, (ndarray, np.void)):
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 93f6d4be0..707fcd1de 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1788,6 +1788,26 @@ class TestFillingValues(TestCase):
assert_equal(b['a']._data, a._data)
assert_equal(b['a'].fill_value, a.fill_value)
+ def test_default_fill_value(self):
+ # check all calling conventions
+ f1 = default_fill_value(1.)
+ f2 = default_fill_value(np.array(1.))
+ f3 = default_fill_value(np.array(1.).dtype)
+ assert_equal(f1, f2)
+ assert_equal(f1, f3)
+
+ def test_default_fill_value_structured(self):
+ fields = array([(1, 1, 1)],
+ dtype=[('i', int), ('s', '|S8'), ('f', float)])
+
+ f1 = default_fill_value(fields)
+ f2 = default_fill_value(fields.dtype)
+ expected = np.array((default_fill_value(0),
+ default_fill_value('0'),
+ default_fill_value(0.)), dtype=fields.dtype)
+ assert_equal(f1, expected)
+ assert_equal(f2, expected)
+
def test_fillvalue(self):
# Yet more fun with the fill_value
data = masked_array([1, 2, 3], fill_value=-999)
@@ -1863,22 +1883,36 @@ class TestFillingValues(TestCase):
a = array([(1, (2, 3)), (4, (5, 6))],
dtype=[('A', int), ('B', [('BA', int), ('BB', int)])])
test = a.fill_value
+ assert_equal(test.dtype, a.dtype)
assert_equal(test['A'], default_fill_value(a['A']))
assert_equal(test['B']['BA'], default_fill_value(a['B']['BA']))
assert_equal(test['B']['BB'], default_fill_value(a['B']['BB']))
test = minimum_fill_value(a)
+ assert_equal(test.dtype, a.dtype)
assert_equal(test[0], minimum_fill_value(a['A']))
assert_equal(test[1][0], minimum_fill_value(a['B']['BA']))
assert_equal(test[1][1], minimum_fill_value(a['B']['BB']))
assert_equal(test[1], minimum_fill_value(a['B']))
test = maximum_fill_value(a)
+ assert_equal(test.dtype, a.dtype)
assert_equal(test[0], maximum_fill_value(a['A']))
assert_equal(test[1][0], maximum_fill_value(a['B']['BA']))
assert_equal(test[1][1], maximum_fill_value(a['B']['BB']))
assert_equal(test[1], maximum_fill_value(a['B']))
+ def test_extremum_fill_value_subdtype(self):
+ a = array(([2, 3, 4],), dtype=[('value', np.int8, 3)])
+
+ test = minimum_fill_value(a)
+ assert_equal(test.dtype, a.dtype)
+ assert_equal(test[0], np.full(3, minimum_fill_value(a['value'])))
+
+ test = maximum_fill_value(a)
+ assert_equal(test.dtype, a.dtype)
+ assert_equal(test[0], np.full(3, maximum_fill_value(a['value'])))
+
def test_fillvalue_individual_fields(self):
# Test setting fill_value on individual fields
ndtype = [('a', int), ('b', int)]
@@ -3097,27 +3131,41 @@ class TestMaskedArrayMethods(TestCase):
assert_equal(am, an)
def test_sort_flexible(self):
- # Test sort on flexible dtype.
+ # Test sort on structured dtype.
a = array(
data=[(3, 3), (3, 2), (2, 2), (2, 1), (1, 0), (1, 1), (1, 2)],
mask=[(0, 0), (0, 1), (0, 0), (0, 0), (1, 0), (0, 0), (0, 0)],
dtype=[('A', int), ('B', int)])
-
- test = sort(a)
- b = array(
+ mask_last = array(
data=[(1, 1), (1, 2), (2, 1), (2, 2), (3, 3), (3, 2), (1, 0)],
mask=[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 1), (1, 0)],
dtype=[('A', int), ('B', int)])
- assert_equal(test, b)
- assert_equal(test.mask, b.mask)
+ mask_first = array(
+ data=[(1, 0), (1, 1), (1, 2), (2, 1), (2, 2), (3, 2), (3, 3)],
+ mask=[(1, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 1), (0, 0)],
+ dtype=[('A', int), ('B', int)])
+
+ test = sort(a)
+ assert_equal(test, mask_last)
+ assert_equal(test.mask, mask_last.mask)
test = sort(a, endwith=False)
- b = array(
- data=[(1, 0), (1, 1), (1, 2), (2, 1), (2, 2), (3, 2), (3, 3), ],
- mask=[(1, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 1), (0, 0), ],
- dtype=[('A', int), ('B', int)])
- assert_equal(test, b)
- assert_equal(test.mask, b.mask)
+ assert_equal(test, mask_first)
+ assert_equal(test.mask, mask_first.mask)
+
+ # Test sort on dtype with subarray (gh-8069)
+ dt = np.dtype([('v', int, 2)])
+ a = a.view(dt)
+ mask_last = mask_last.view(dt)
+ mask_first = mask_first.view(dt)
+
+ test = sort(a)
+ assert_equal(test, mask_last)
+ assert_equal(test.mask, mask_last.mask)
+
+ test = sort(a, endwith=False)
+ assert_equal(test, mask_first)
+ assert_equal(test.mask, mask_first.mask)
def test_argsort(self):
# Test argsort