diff options
author | Bas van Beek <43369155+BvB93@users.noreply.github.com> | 2020-08-04 06:00:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-04 07:00:37 +0300 |
commit | 593ef5fc5a02fbcd6eeb70a59684b3b21c9cc643 (patch) | |
tree | e811d881547f6cab8998b887d9a7dfa4e5965641 /numpy/lib/tests | |
parent | 8f6052280c445328c7f7436917c26fc295429173 (diff) | |
download | numpy-593ef5fc5a02fbcd6eeb70a59684b3b21c9cc643.tar.gz |
ENH: Speed up trim_zeros (#16911)
* Added a benchmark for `trim_zeros()`
* Improve the performance of `np.trim_zeros()`
* Increase the variety of the tests
Fall back to the old `np.trim_zeros()` implementation if an exception is encountered.
Emit a `DeprecationWarning` in such case.
* DEP,REL: Added a deprecation release note
Diffstat (limited to 'numpy/lib/tests')
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 46 |
1 files changed, 34 insertions, 12 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index eb2fc3311..89c1a2d9b 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -1166,25 +1166,47 @@ class TestAngle: class TestTrimZeros: - """ - Only testing for integer splits. + a = np.array([0, 0, 1, 0, 2, 3, 4, 0]) + b = a.astype(float) + c = a.astype(complex) + d = np.array([None, [], 1, False, 'b', 3.0, range(4), b''], dtype=object) - """ + def values(self): + attr_names = ('a', 'b', 'c', 'd') + return (getattr(self, name) for name in attr_names) def test_basic(self): - a = np.array([0, 0, 1, 2, 3, 4, 0]) - res = trim_zeros(a) - assert_array_equal(res, np.array([1, 2, 3, 4])) + slc = np.s_[2:-1] + for arr in self.values(): + res = trim_zeros(arr) + assert_array_equal(res, arr[slc]) def test_leading_skip(self): - a = np.array([0, 0, 1, 0, 2, 3, 4, 0]) - res = trim_zeros(a) - assert_array_equal(res, np.array([1, 0, 2, 3, 4])) + slc = np.s_[:-1] + for arr in self.values(): + res = trim_zeros(arr, trim='b') + assert_array_equal(res, arr[slc]) def test_trailing_skip(self): - a = np.array([0, 0, 1, 0, 2, 3, 0, 4, 0]) - res = trim_zeros(a) - assert_array_equal(res, np.array([1, 0, 2, 3, 0, 4])) + slc = np.s_[2:] + for arr in self.values(): + res = trim_zeros(arr, trim='F') + assert_array_equal(res, arr[slc]) + + def test_all_zero(self): + for _arr in self.values(): + arr = np.zeros_like(_arr, dtype=_arr.dtype) + + res1 = trim_zeros(arr, trim='B') + assert len(res1) == 0 + + res2 = trim_zeros(arr, trim='f') + assert len(res2) == 0 + + def test_size_zero(self): + arr = np.zeros(0) + res = trim_zeros(arr) + assert_array_equal(arr, res) class TestExtins: |