diff options
author | Matti Picus <matti.picus@gmail.com> | 2020-08-19 13:39:51 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-19 13:39:51 +0300 |
commit | 05a88ad2aacd16f8e38a44fece6588f6ee840a32 (patch) | |
tree | 271faeb5160242100e123296905fc72d4056e62d /numpy/lib/function_base.py | |
parent | 0dc55882a976630d832ebebdb58350d9c26205fe (diff) | |
parent | 26734efdb1c8ee269bb97acc22b587027a5a3f88 (diff) | |
download | numpy-05a88ad2aacd16f8e38a44fece6588f6ee840a32.tar.gz |
Merge pull request #17058 from BvB93/trim_zeros2
MAINT: Revert boolean casting back to elementwise comparisons in `trim_zeros`
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r-- | numpy/lib/function_base.py | 22 |
1 files changed, 14 insertions, 8 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index b530f0aa1..556227c0d 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1631,21 +1631,27 @@ def trim_zeros(filt, trim='fb'): # Numpy 1.20.0, 2020-07-31 warning = DeprecationWarning( "in the future trim_zeros will require a 1-D array as input " - "that is compatible with ndarray.astype(bool)" + "that supports elementwise comparisons with zero" ) warning.__cause__ = ex - warnings.warn(warning, stacklevel=3) - # Fall back to the old implementation if an exception is encountered - # Note that the same exception may or may not be raised here as well - return _trim_zeros_old(filt, trim) + # Fall back to the old implementation if an exception is encountered + # Note that the same exception may or may not be raised here as well + ret = _trim_zeros_old(filt, trim) + warnings.warn(warning, stacklevel=3) + return ret def _trim_zeros_new(filt, trim='fb'): """Newer optimized implementation of ``trim_zeros()``.""" - arr = np.asanyarray(filt).astype(bool, copy=False) - - if arr.ndim != 1: + arr_any = np.asanyarray(filt) + arr = arr_any != 0 if arr_any.dtype != bool else arr_any + + if arr is False: + # not all dtypes support elementwise comparisons with `0` (e.g. str); + # they will return `False` instead + raise TypeError('elementwise comparison failed; unsupported data type') + elif arr.ndim != 1: raise ValueError('trim_zeros requires an array of exactly one dimension') elif not len(arr): return filt |