diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2014-02-24 15:22:44 +0100 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2014-03-23 20:33:16 +0100 |
commit | 123b319be37f01e3c4f2e42552d4ca121b27ca38 (patch) | |
tree | d3097f53c4d18e9b5bd6f8d9a8e30879ed4a033f /numpy/lib/tests/test_function_base.py | |
parent | 3e00e0058fb28bf22018d0d641f4a51814f5c9bb (diff) | |
download | numpy-123b319be37f01e3c4f2e42552d4ca121b27ca38.tar.gz |
ENH: Speed improvements and deprecations for np.select
The idea for this (and some of the code) originally comes from
Graeme B Bell (gh-3537).
Choose is not as fast and pretty limited, so an iterative
copyto is used instead.
Closes gh-3259, gh-3537, gh-3551, and gh-3254
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 63 |
1 files changed, 57 insertions, 6 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 3e102cf6a..145a7be4d 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -150,6 +150,13 @@ class TestAverage(TestCase): class TestSelect(TestCase): + choices = [np.array([1, 2, 3]), + np.array([4, 5, 6]), + np.array([7, 8, 9])] + conditions = [np.array([False, False, False]), + np.array([False, True, False]), + np.array([False, False, True])] + def _select(self, cond, values, default=0): output = [] for m in range(len(cond)): @@ -157,18 +164,62 @@ class TestSelect(TestCase): return output def test_basic(self): - choices = [np.array([1, 2, 3]), - np.array([4, 5, 6]), - np.array([7, 8, 9])] - conditions = [np.array([0, 0, 0]), - np.array([0, 1, 0]), - np.array([0, 0, 1])] + choices = self.choices + conditions = self.conditions assert_array_equal(select(conditions, choices, default=15), self._select(conditions, choices, default=15)) assert_equal(len(choices), 3) assert_equal(len(conditions), 3) + def test_broadcasting(self): + conditions = [np.array(True), np.array([False, True, False])] + choices = [1, np.arange(12).reshape(4, 3)] + assert_array_equal(select(conditions, choices), np.ones((4, 3))) + # default can broadcast too: + assert_equal(select([True], [0], default=[0]).shape, (1,)) + + def test_return_dtype(self): + assert_equal(select(self.conditions, self.choices, 1j).dtype, + np.complex_) + # But the conditions need to be stronger then the scalar default + # if it is scalar. + choices = [choice.astype(np.int8) for choice in self.choices] + assert_equal(select(self.conditions, choices).dtype, np.int8) + + d = np.array([1, 2, 3, np.nan, 5, 7]) + m = np.isnan(d) + assert_equal(select([m], [d]), [0, 0, 0, np.nan, 0, 0]) + + def test_deprecated_empty(self): + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + assert_equal(select([], [], 3j), 3j) + + with warnings.catch_warnings(): + warnings.simplefilter("always") + assert_warns(DeprecationWarning, select, [], []) + warnings.simplefilter("error") + assert_raises(DeprecationWarning, select, [], []) + + def test_non_bool_deprecation(self): + choices = self.choices + conditions = self.conditions[:] + with warnings.catch_warnings(): + warnings.filterwarnings("always") + conditions[0] = conditions[0].astype(np.int_) + assert_warns(DeprecationWarning, select, conditions, choices) + conditions[0] = conditions[0].astype(np.uint8) + assert_warns(DeprecationWarning, select, conditions, choices) + warnings.filterwarnings("error") + assert_raises(DeprecationWarning, select, conditions, choices) + + def test_many_arguments(self): + # This used to be limited by NPY_MAXARGS == 32 + conditions = [np.array([False])] * 100 + choices = [np.array([1])] * 100 + select(conditions, choices) + class TestInsert(TestCase): def test_basic(self): |