summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-10-23 11:18:32 -0600
committerCharles Harris <charlesr.harris@gmail.com>2018-10-24 09:38:13 -0600
commitc8bc1499b5933d189945c1f8956df183c60c8603 (patch)
tree7f3f1c6d2ddb4ec0998b5e78b4bdc6f1de19ac1d /numpy/ma
parent2bdb732eb7d04a6016f5e18efe02365f47a5c6b1 (diff)
downloadnumpy-c8bc1499b5933d189945c1f8956df183c60c8603.tar.gz
BUG: Fix fill value in masked array '==' and '!=' ops.
The type of the fill_value needs to be `bool_` in order to match the result type of `==` and `!=`. Closes #12248.
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py23
1 files changed, 18 insertions, 5 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 9259aa4c7..9ee44e9ff 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -446,6 +446,7 @@ def _check_fill_value(fill_value, ndtype):
If fill_value is not None, its value is forced to the given dtype.
The result is always a 0d array.
+
"""
ndtype = np.dtype(ndtype)
fields = ndtype.fields
@@ -465,17 +466,19 @@ def _check_fill_value(fill_value, ndtype):
dtype=ndtype)
else:
if isinstance(fill_value, basestring) and (ndtype.char not in 'OSVU'):
+ # Note this check doesn't work if fill_value is not a scalar
err_msg = "Cannot set fill value of string with array of dtype %s"
raise TypeError(err_msg % ndtype)
else:
# In case we want to convert 1e20 to int.
+ # Also in case of converting string arrays.
try:
fill_value = np.array(fill_value, copy=False, dtype=ndtype)
- except OverflowError:
- # Raise TypeError instead of OverflowError. OverflowError
- # is seldom used, and the real problem here is that the
- # passed fill_value is not compatible with the ndtype.
- err_msg = "Fill value %s overflows dtype %s"
+ except (OverflowError, ValueError):
+ # Raise TypeError instead of OverflowError or ValueError.
+ # OverflowError is seldom used, and the real problem here is
+ # that the passed fill_value is not compatible with the ndtype.
+ err_msg = "Cannot convert fill_value %s to dtype %s"
raise TypeError(err_msg % (fill_value, ndtype))
return np.array(fill_value)
@@ -4014,6 +4017,16 @@ class MaskedArray(ndarray):
check = check.view(type(self))
check._update_from(self)
check._mask = mask
+
+ # Cast fill value to bool_ if needed. If it cannot be cast, the
+ # default boolean fill value is used.
+ if check._fill_value is not None:
+ try:
+ fill = _check_fill_value(check._fill_value, np.bool_)
+ except (TypeError, ValueError):
+ fill = _check_fill_value(None, np.bool_)
+ check._fill_value = fill
+
return check
def __eq__(self, other):