summaryrefslogtreecommitdiff
path: root/numpy/ma/core.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-10-26 09:35:23 -0500
committerGitHub <noreply@github.com>2018-10-26 09:35:23 -0500
commit3debe9772ea1b68d997dba3440929a467ad11c52 (patch)
treee9eaef760e19f86622cafa19ff64a0a46a60fbc3 /numpy/ma/core.py
parent763be37f77b572f46637db4032cebfd0f96175d9 (diff)
parent173f65c02c8d654fc55f381665ec43c64d43d636 (diff)
downloadnumpy-3debe9772ea1b68d997dba3440929a467ad11c52.tar.gz
Merge pull request #12257 from charris/fix-ma-compare-fill_value
BUG: Fix fill value in masked array '==' and '!=' ops.
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r--numpy/ma/core.py23
1 files changed, 18 insertions, 5 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 9259aa4c7..9ee44e9ff 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -446,6 +446,7 @@ def _check_fill_value(fill_value, ndtype):
If fill_value is not None, its value is forced to the given dtype.
The result is always a 0d array.
+
"""
ndtype = np.dtype(ndtype)
fields = ndtype.fields
@@ -465,17 +466,19 @@ def _check_fill_value(fill_value, ndtype):
dtype=ndtype)
else:
if isinstance(fill_value, basestring) and (ndtype.char not in 'OSVU'):
+ # Note this check doesn't work if fill_value is not a scalar
err_msg = "Cannot set fill value of string with array of dtype %s"
raise TypeError(err_msg % ndtype)
else:
# In case we want to convert 1e20 to int.
+ # Also in case of converting string arrays.
try:
fill_value = np.array(fill_value, copy=False, dtype=ndtype)
- except OverflowError:
- # Raise TypeError instead of OverflowError. OverflowError
- # is seldom used, and the real problem here is that the
- # passed fill_value is not compatible with the ndtype.
- err_msg = "Fill value %s overflows dtype %s"
+ except (OverflowError, ValueError):
+ # Raise TypeError instead of OverflowError or ValueError.
+ # OverflowError is seldom used, and the real problem here is
+ # that the passed fill_value is not compatible with the ndtype.
+ err_msg = "Cannot convert fill_value %s to dtype %s"
raise TypeError(err_msg % (fill_value, ndtype))
return np.array(fill_value)
@@ -4014,6 +4017,16 @@ class MaskedArray(ndarray):
check = check.view(type(self))
check._update_from(self)
check._mask = mask
+
+ # Cast fill value to bool_ if needed. If it cannot be cast, the
+ # default boolean fill value is used.
+ if check._fill_value is not None:
+ try:
+ fill = _check_fill_value(check._fill_value, np.bool_)
+ except (TypeError, ValueError):
+ fill = _check_fill_value(None, np.bool_)
+ check._fill_value = fill
+
return check
def __eq__(self, other):