summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_function_base.py
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2014-02-24 15:22:44 +0100
committerSebastian Berg <sebastian@sipsolutions.net>2014-03-23 20:33:16 +0100
commit123b319be37f01e3c4f2e42552d4ca121b27ca38 (patch)
treed3097f53c4d18e9b5bd6f8d9a8e30879ed4a033f /numpy/lib/tests/test_function_base.py
parent3e00e0058fb28bf22018d0d641f4a51814f5c9bb (diff)
downloadnumpy-123b319be37f01e3c4f2e42552d4ca121b27ca38.tar.gz
ENH: Speed improvements and deprecations for np.select
The idea for this (and some of the code) originally comes from Graeme B Bell (gh-3537). Choose is not as fast and pretty limited, so an iterative copyto is used instead. Closes gh-3259, gh-3537, gh-3551, and gh-3254
Diffstat (limited to 'numpy/lib/tests/test_function_base.py')
-rw-r--r--numpy/lib/tests/test_function_base.py63
1 files changed, 57 insertions, 6 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 3e102cf6a..145a7be4d 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -150,6 +150,13 @@ class TestAverage(TestCase):
class TestSelect(TestCase):
+ choices = [np.array([1, 2, 3]),
+ np.array([4, 5, 6]),
+ np.array([7, 8, 9])]
+ conditions = [np.array([False, False, False]),
+ np.array([False, True, False]),
+ np.array([False, False, True])]
+
def _select(self, cond, values, default=0):
output = []
for m in range(len(cond)):
@@ -157,18 +164,62 @@ class TestSelect(TestCase):
return output
def test_basic(self):
- choices = [np.array([1, 2, 3]),
- np.array([4, 5, 6]),
- np.array([7, 8, 9])]
- conditions = [np.array([0, 0, 0]),
- np.array([0, 1, 0]),
- np.array([0, 0, 1])]
+ choices = self.choices
+ conditions = self.conditions
assert_array_equal(select(conditions, choices, default=15),
self._select(conditions, choices, default=15))
assert_equal(len(choices), 3)
assert_equal(len(conditions), 3)
+ def test_broadcasting(self):
+ conditions = [np.array(True), np.array([False, True, False])]
+ choices = [1, np.arange(12).reshape(4, 3)]
+ assert_array_equal(select(conditions, choices), np.ones((4, 3)))
+ # default can broadcast too:
+ assert_equal(select([True], [0], default=[0]).shape, (1,))
+
+ def test_return_dtype(self):
+ assert_equal(select(self.conditions, self.choices, 1j).dtype,
+ np.complex_)
+ # But the conditions need to be stronger then the scalar default
+ # if it is scalar.
+ choices = [choice.astype(np.int8) for choice in self.choices]
+ assert_equal(select(self.conditions, choices).dtype, np.int8)
+
+ d = np.array([1, 2, 3, np.nan, 5, 7])
+ m = np.isnan(d)
+ assert_equal(select([m], [d]), [0, 0, 0, np.nan, 0, 0])
+
+ def test_deprecated_empty(self):
+ with warnings.catch_warnings(record=True):
+ warnings.simplefilter("always")
+ assert_equal(select([], [], 3j), 3j)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("always")
+ assert_warns(DeprecationWarning, select, [], [])
+ warnings.simplefilter("error")
+ assert_raises(DeprecationWarning, select, [], [])
+
+ def test_non_bool_deprecation(self):
+ choices = self.choices
+ conditions = self.conditions[:]
+ with warnings.catch_warnings():
+ warnings.filterwarnings("always")
+ conditions[0] = conditions[0].astype(np.int_)
+ assert_warns(DeprecationWarning, select, conditions, choices)
+ conditions[0] = conditions[0].astype(np.uint8)
+ assert_warns(DeprecationWarning, select, conditions, choices)
+ warnings.filterwarnings("error")
+ assert_raises(DeprecationWarning, select, conditions, choices)
+
+ def test_many_arguments(self):
+ # This used to be limited by NPY_MAXARGS == 32
+ conditions = [np.array([False])] * 100
+ choices = [np.array([1])] * 100
+ select(conditions, choices)
+
class TestInsert(TestCase):
def test_basic(self):