summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/ma/core.py14
-rw-r--r--numpy/ma/tests/test_core.py11
2 files changed, 19 insertions, 6 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 9ee44e9ff..4e6469812 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -449,17 +449,15 @@ def _check_fill_value(fill_value, ndtype):
"""
ndtype = np.dtype(ndtype)
- fields = ndtype.fields
if fill_value is None:
fill_value = default_fill_value(ndtype)
- elif fields:
- fdtype = [(_[0], _[1]) for _ in ndtype.descr]
+ elif ndtype.names is not None:
if isinstance(fill_value, (ndarray, np.void)):
try:
- fill_value = np.array(fill_value, copy=False, dtype=fdtype)
+ fill_value = np.array(fill_value, copy=False, dtype=ndtype)
except ValueError:
err_msg = "Unable to transform %s to dtype %s"
- raise ValueError(err_msg % (fill_value, fdtype))
+ raise ValueError(err_msg % (fill_value, ndtype))
else:
fill_value = np.asarray(fill_value, dtype=object)
fill_value = np.array(_recursive_set_fill_value(fill_value, ndtype),
@@ -780,6 +778,10 @@ def fix_invalid(a, mask=nomask, copy=True, fill_value=None):
a._data[invalid] = fill_value
return a
+def is_string_or_list_of_strings(val):
+ return (isinstance(val, basestring) or
+ (isinstance(val, list) and
+ builtins.all(isinstance(s, basestring) for s in val)))
###############################################################################
# Ufuncs #
@@ -3245,7 +3247,7 @@ class MaskedArray(ndarray):
# Inherit attributes from self
dout._update_from(self)
# Check the fill_value
- if isinstance(indx, basestring):
+ if is_string_or_list_of_strings(indx):
if self._fill_value is not None:
dout._fill_value = self._fill_value[indx]
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 8a015e609..82d7ee960 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -2029,6 +2029,17 @@ class TestFillingValues(object):
assert_equal(x.fill_value, 999.)
assert_equal(x._fill_value, np.array(999.))
+ def test_subarray_fillvalue(self):
+ # gh-10483 test multi-field index fill value
+ fields = array([(1, 1, 1)],
+ dtype=[('i', int), ('s', '|S8'), ('f', float)])
+ with suppress_warnings() as sup:
+ sup.filter(FutureWarning, "Numpy has detected")
+ subfields = fields[['i', 'f']]
+ assert_equal(tuple(subfields.fill_value), (999999, 1.e+20))
+ # test comparison does not raise:
+ subfields[1:] == subfields[:-1]
+
def test_fillvalue_exotic_dtype(self):
# Tests yet more exotic flexible dtypes
_check_fill_value = np.ma.core._check_fill_value