summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_function_base.py
diff options
context:
space:
mode:
authorBas van Beek <43369155+BvB93@users.noreply.github.com>2020-08-04 06:00:37 +0200
committerGitHub <noreply@github.com>2020-08-04 07:00:37 +0300
commit593ef5fc5a02fbcd6eeb70a59684b3b21c9cc643 (patch)
treee811d881547f6cab8998b887d9a7dfa4e5965641 /numpy/lib/tests/test_function_base.py
parent8f6052280c445328c7f7436917c26fc295429173 (diff)
downloadnumpy-593ef5fc5a02fbcd6eeb70a59684b3b21c9cc643.tar.gz
ENH: Speed up trim_zeros (#16911)
* Added a benchmark for `trim_zeros()` * Improve the performance of `np.trim_zeros()` * Increase the variety of the tests Fall back to the old `np.trim_zeros()` implementation if an exception is encountered. Emit a `DeprecationWarning` in such case. * DEP,REL: Added a deprecation release note
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r--numpy/lib/tests/test_function_base.py46
1 files changed, 34 insertions, 12 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index eb2fc3311..89c1a2d9b 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1166,25 +1166,47 @@ class TestAngle:
class TestTrimZeros:
- """
- Only testing for integer splits.
+ a = np.array([0, 0, 1, 0, 2, 3, 4, 0])
+ b = a.astype(float)
+ c = a.astype(complex)
+ d = np.array([None, [], 1, False, 'b', 3.0, range(4), b''], dtype=object)
- """
+ def values(self):
+ attr_names = ('a', 'b', 'c', 'd')
+ return (getattr(self, name) for name in attr_names)
def test_basic(self):
- a = np.array([0, 0, 1, 2, 3, 4, 0])
- res = trim_zeros(a)
- assert_array_equal(res, np.array([1, 2, 3, 4]))
+ slc = np.s_[2:-1]
+ for arr in self.values():
+ res = trim_zeros(arr)
+ assert_array_equal(res, arr[slc])
def test_leading_skip(self):
- a = np.array([0, 0, 1, 0, 2, 3, 4, 0])
- res = trim_zeros(a)
- assert_array_equal(res, np.array([1, 0, 2, 3, 4]))
+ slc = np.s_[:-1]
+ for arr in self.values():
+ res = trim_zeros(arr, trim='b')
+ assert_array_equal(res, arr[slc])
def test_trailing_skip(self):
- a = np.array([0, 0, 1, 0, 2, 3, 0, 4, 0])
- res = trim_zeros(a)
- assert_array_equal(res, np.array([1, 0, 2, 3, 0, 4]))
+ slc = np.s_[2:]
+ for arr in self.values():
+ res = trim_zeros(arr, trim='F')
+ assert_array_equal(res, arr[slc])
+
+ def test_all_zero(self):
+ for _arr in self.values():
+ arr = np.zeros_like(_arr, dtype=_arr.dtype)
+
+ res1 = trim_zeros(arr, trim='B')
+ assert len(res1) == 0
+
+ res2 = trim_zeros(arr, trim='f')
+ assert len(res2) == 0
+
+ def test_size_zero(self):
+ arr = np.zeros(0)
+ res = trim_zeros(arr)
+ assert_array_equal(arr, res)
class TestExtins: