From 593ef5fc5a02fbcd6eeb70a59684b3b21c9cc643 Mon Sep 17 00:00:00 2001 From: Bas van Beek <43369155+BvB93@users.noreply.github.com> Date: Tue, 4 Aug 2020 06:00:37 +0200 Subject: ENH: Speed up trim_zeros (#16911) * Added a benchmark for `trim_zeros()` * Improve the performance of `np.trim_zeros()` * Increase the variety of the tests Fall back to the old `np.trim_zeros()` implementation if an exception is encountered. Emit a `DeprecationWarning` in such case. * DEP,REL: Added a deprecation release note --- numpy/lib/tests/test_function_base.py | 46 ++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 12 deletions(-) (limited to 'numpy/lib/tests/test_function_base.py') diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index eb2fc3311..89c1a2d9b 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -1166,25 +1166,47 @@ class TestAngle: class TestTrimZeros: - """ - Only testing for integer splits. + a = np.array([0, 0, 1, 0, 2, 3, 4, 0]) + b = a.astype(float) + c = a.astype(complex) + d = np.array([None, [], 1, False, 'b', 3.0, range(4), b''], dtype=object) - """ + def values(self): + attr_names = ('a', 'b', 'c', 'd') + return (getattr(self, name) for name in attr_names) def test_basic(self): - a = np.array([0, 0, 1, 2, 3, 4, 0]) - res = trim_zeros(a) - assert_array_equal(res, np.array([1, 2, 3, 4])) + slc = np.s_[2:-1] + for arr in self.values(): + res = trim_zeros(arr) + assert_array_equal(res, arr[slc]) def test_leading_skip(self): - a = np.array([0, 0, 1, 0, 2, 3, 4, 0]) - res = trim_zeros(a) - assert_array_equal(res, np.array([1, 0, 2, 3, 4])) + slc = np.s_[:-1] + for arr in self.values(): + res = trim_zeros(arr, trim='b') + assert_array_equal(res, arr[slc]) def test_trailing_skip(self): - a = np.array([0, 0, 1, 0, 2, 3, 0, 4, 0]) - res = trim_zeros(a) - assert_array_equal(res, np.array([1, 0, 2, 3, 0, 4])) + slc = np.s_[2:] + for arr in self.values(): + res = trim_zeros(arr, trim='F') + assert_array_equal(res, arr[slc]) + + def test_all_zero(self): + for _arr in self.values(): + arr = np.zeros_like(_arr, dtype=_arr.dtype) + + res1 = trim_zeros(arr, trim='B') + assert len(res1) == 0 + + res2 = trim_zeros(arr, trim='f') + assert len(res2) == 0 + + def test_size_zero(self): + arr = np.zeros(0) + res = trim_zeros(arr) + assert_array_equal(arr, res) class TestExtins: -- cgit v1.2.1