summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-06-27 14:44:56 +0100
committerEric Wieser <wieser.eric@gmail.com>2017-06-27 18:42:47 +0100
commit3600068426bfc86c1ccfa85a53394c6f6164a96f (patch)
tree2c657186a29480db97cf4f14b1b8a67798cf8066 /numpy/ma
parentc80f6b640c3ca39594e3164637c57fec406275a3 (diff)
downloadnumpy-3600068426bfc86c1ccfa85a53394c6f6164a96f.tar.gz
BUG: Overhaul *_fill_value functions
These now: * Return void scalars instead of tuples for structured dtypes * Actually accept non-trivial dtype arguments * Work for nested structured and subarray dtypes * Uniformly convert scalar types into dtypes Also fixes #9315
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py132
-rw-r--r--numpy/ma/tests/test_core.py25
2 files changed, 75 insertions, 82 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index ed8652a45..c2b10d5f8 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -205,6 +205,31 @@ if 'float128' in ntypes.typeDict:
min_filler.update([(np.float128, +np.inf)])
+def _recursive_fill_value(dtype, f):
+ """
+ Recursively produce a fill value for `dtype`, calling f on scalar dtypes
+ """
+ if dtype.names:
+ vals = tuple(_recursive_fill_value(dtype[name], f) for name in dtype.names)
+ return np.array(vals, dtype=dtype)[()] # decay to void scalar from 0d
+ elif dtype.subdtype:
+ subtype, shape = dtype.subdtype
+ subval = _recursive_fill_value(subtype, f)
+ return np.full(shape, subval)
+ else:
+ return f(dtype)
+
+
+def _get_dtype_of(obj):
+ """ Convert the argument for *_fill_value into a dtype """
+ if isinstance(obj, np.dtype):
+ return obj
+ elif hasattr(obj, 'dtype'):
+ return obj.dtype
+ else:
+ return np.asanyarray(obj).dtype
+
+
def default_fill_value(obj):
"""
Return the default fill value for the argument object.
@@ -223,6 +248,11 @@ def default_fill_value(obj):
string 'N/A'
======== ========
+ For structured types, a structured scalar is returned, with each field the
+ default fill value for its type.
+
+ For subarray types, the fill value is an array of the same size containing
+ the default scalar fill value.
Parameters
----------
@@ -245,62 +275,29 @@ def default_fill_value(obj):
(1e+20+0j)
"""
- if hasattr(obj, 'dtype'):
- return default_fill_value(obj.dtype)
- elif isinstance(obj, np.dtype):
- if obj.names:
- return np.array(_recursive_set_default_fill_value(obj),
- dtype=obj)
- elif obj.subdtype:
- return default_filler.get(obj.subdtype[0].kind, '?')
- elif obj.kind in 'Mm':
- return default_filler.get(obj.str[1:], '?')
+ def _scalar_fill_value(dtype):
+ if dtype.kind in 'Mm':
+ return default_filler.get(dtype.str[1:], '?')
else:
- return default_filler.get(obj.kind, '?')
- elif isinstance(obj, float):
- return default_filler['f']
- elif isinstance(obj, int) or isinstance(obj, long):
- return default_filler['i']
- elif isinstance(obj, bytes):
- return default_filler['S']
- elif isinstance(obj, unicode):
- return default_filler['U']
- elif isinstance(obj, complex):
- return default_filler['c']
- else:
- return default_filler['O']
+ return default_filler.get(dtype.kind, '?')
-
-def _recursive_extremum_fill_value(ndtype, extremum):
- names = ndtype.names
- if names:
- deflist = []
- for name in names:
- fval = _recursive_extremum_fill_value(ndtype[name], extremum)
- deflist.append(fval)
- return tuple(deflist)
- elif ndtype.subdtype:
- subtype, shape = ndtype.subdtype
- subval = _recursive_extremum_fill_value(subtype, extremum)
- return np.full(shape, subval)
- else:
- return extremum[ndtype]
+ dtype = _get_dtype_of(obj)
+ return _recursive_fill_value(dtype, _scalar_fill_value)
def _extremum_fill_value(obj, extremum, extremum_name):
- if hasattr(obj, 'dtype'):
- return _recursive_extremum_fill_value(obj.dtype, extremum)
- elif isinstance(obj, float):
- return extremum[ntypes.typeDict['float_']]
- elif isinstance(obj, int):
- return extremum[ntypes.typeDict['int_']]
- elif isinstance(obj, long):
- return extremum[ntypes.typeDict['uint']]
- elif isinstance(obj, np.dtype):
- return extremum[obj]
- else:
- raise TypeError(
- "Unsuitable type for calculating {}.".format(extremum_name))
+
+ def _scalar_fill_value(dtype):
+ try:
+ return extremum[dtype]
+ except KeyError:
+ raise TypeError(
+ "Unsuitable type {} for calculating {}."
+ .format(dtype, extremum_name)
+ )
+
+ dtype = _get_dtype_of(obj)
+ return _recursive_fill_value(dtype, _scalar_fill_value)
def minimum_fill_value(obj):
@@ -312,7 +309,7 @@ def minimum_fill_value(obj):
Parameters
----------
- obj : ndarray or dtype
+ obj : ndarray, dtype or scalar
An object that can be queried for it's numeric type.
Returns
@@ -363,7 +360,7 @@ def maximum_fill_value(obj):
Parameters
----------
- obj : {ndarray, dtype}
+ obj : ndarray, dtype or scalar
An object that can be queried for it's numeric type.
Returns
@@ -405,35 +402,6 @@ def maximum_fill_value(obj):
return _extremum_fill_value(obj, max_filler, "maximum")
-def _recursive_set_default_fill_value(dt):
- """
- Create the default fill value for a structured dtype.
-
- Parameters
- ----------
- dt: dtype
- The structured dtype for which to create the fill value.
-
- Returns
- -------
- val: tuple
- A tuple of values corresponding to the default structured fill value.
-
- """
- deflist = []
- for name in dt.names:
- currenttype = dt[name]
- if currenttype.subdtype:
- currenttype = currenttype.subdtype[0]
-
- if currenttype.names:
- deflist.append(
- tuple(_recursive_set_default_fill_value(currenttype)))
- else:
- deflist.append(default_fill_value(currenttype))
- return tuple(deflist)
-
-
def _recursive_set_fill_value(fillvalue, dt):
"""
Create a fill value for a structured dtype.
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 3a401caed..816e149a9 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1788,6 +1788,26 @@ class TestFillingValues(TestCase):
assert_equal(b['a']._data, a._data)
assert_equal(b['a'].fill_value, a.fill_value)
+ def test_default_fill_value(self):
+ # check all calling conventions
+ f1 = default_fill_value(1.)
+ f2 = default_fill_value(np.array(1.))
+ f3 = default_fill_value(np.array(1.).dtype)
+ assert_equal(f1, f2)
+ assert_equal(f1, f3)
+
+ def test_default_fill_value_structured(self):
+ fields = array([(1, 1, 1)],
+ dtype=[('i', int), ('s', '|S8'), ('f', float)])
+
+ f1 = default_fill_value(fields)
+ f2 = default_fill_value(fields.dtype)
+ expected = np.array((default_fill_value(0),
+ default_fill_value('0'),
+ default_fill_value(0.)), dtype=fields.dtype)
+ assert_equal(f1, expected)
+ assert_equal(f2, expected)
+
def test_fillvalue(self):
# Yet more fun with the fill_value
data = masked_array([1, 2, 3], fill_value=-999)
@@ -1863,17 +1883,20 @@ class TestFillingValues(TestCase):
a = array([(1, (2, 3)), (4, (5, 6))],
dtype=[('A', int), ('B', [('BA', int), ('BB', int)])])
test = a.fill_value
+ assert_equal(test.dtype, a.dtype)
assert_equal(test['A'], default_fill_value(a['A']))
assert_equal(test['B']['BA'], default_fill_value(a['B']['BA']))
assert_equal(test['B']['BB'], default_fill_value(a['B']['BB']))
test = minimum_fill_value(a)
+ assert_equal(test.dtype, a.dtype)
assert_equal(test[0], minimum_fill_value(a['A']))
assert_equal(test[1][0], minimum_fill_value(a['B']['BA']))
assert_equal(test[1][1], minimum_fill_value(a['B']['BB']))
assert_equal(test[1], minimum_fill_value(a['B']))
test = maximum_fill_value(a)
+ assert_equal(test.dtype, a.dtype)
assert_equal(test[0], maximum_fill_value(a['A']))
assert_equal(test[1][0], maximum_fill_value(a['B']['BA']))
assert_equal(test[1][1], maximum_fill_value(a['B']['BB']))
@@ -1883,9 +1906,11 @@ class TestFillingValues(TestCase):
a = array(([2, 3, 4],), dtype=[('value', np.int8, 3)])
test = minimum_fill_value(a)
+ assert_equal(test.dtype, a.dtype)
assert_equal(test[0], np.full(3, minimum_fill_value(a['value'])))
test = maximum_fill_value(a)
+ assert_equal(test.dtype, a.dtype)
assert_equal(test[0], np.full(3, maximum_fill_value(a['value'])))
def test_fillvalue_individual_fields(self):