summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2020-08-19 13:39:51 +0300
committerGitHub <noreply@github.com>2020-08-19 13:39:51 +0300
commit05a88ad2aacd16f8e38a44fece6588f6ee840a32 (patch)
tree271faeb5160242100e123296905fc72d4056e62d /numpy/lib/function_base.py
parent0dc55882a976630d832ebebdb58350d9c26205fe (diff)
parent26734efdb1c8ee269bb97acc22b587027a5a3f88 (diff)
downloadnumpy-05a88ad2aacd16f8e38a44fece6588f6ee840a32.tar.gz
Merge pull request #17058 from BvB93/trim_zeros2
MAINT: Revert boolean casting back to elementwise comparisons in `trim_zeros`
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py22
1 files changed, 14 insertions, 8 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index b530f0aa1..556227c0d 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -1631,21 +1631,27 @@ def trim_zeros(filt, trim='fb'):
# Numpy 1.20.0, 2020-07-31
warning = DeprecationWarning(
"in the future trim_zeros will require a 1-D array as input "
- "that is compatible with ndarray.astype(bool)"
+ "that supports elementwise comparisons with zero"
)
warning.__cause__ = ex
- warnings.warn(warning, stacklevel=3)
- # Fall back to the old implementation if an exception is encountered
- # Note that the same exception may or may not be raised here as well
- return _trim_zeros_old(filt, trim)
+ # Fall back to the old implementation if an exception is encountered
+ # Note that the same exception may or may not be raised here as well
+ ret = _trim_zeros_old(filt, trim)
+ warnings.warn(warning, stacklevel=3)
+ return ret
def _trim_zeros_new(filt, trim='fb'):
"""Newer optimized implementation of ``trim_zeros()``."""
- arr = np.asanyarray(filt).astype(bool, copy=False)
-
- if arr.ndim != 1:
+ arr_any = np.asanyarray(filt)
+ arr = arr_any != 0 if arr_any.dtype != bool else arr_any
+
+ if arr is False:
+ # not all dtypes support elementwise comparisons with `0` (e.g. str);
+ # they will return `False` instead
+ raise TypeError('elementwise comparison failed; unsupported data type')
+ elif arr.ndim != 1:
raise ValueError('trim_zeros requires an array of exactly one dimension')
elif not len(arr):
return filt