summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2007-05-28 11:51:55 +0000
committerStefan van der Walt <stefan@sun.ac.za>2007-05-28 11:51:55 +0000
commit944c32ad4a0618c834dcb06e50e90267df1d6835 (patch)
treec5009067c582ec226b69e2267bb37dfa4c245d3d /numpy/lib
parent840bd64e600ac458b17fd058a181b860e87d56bd (diff)
downloadnumpy-944c32ad4a0618c834dcb06e50e90267df1d6835.tar.gz
Select should not modify output arguments. Add test for basic select functionality.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py6
-rw-r--r--numpy/lib/tests/test_function_base.py16
2 files changed, 17 insertions, 5 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index e038a4803..5d0d7c1a6 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -286,9 +286,7 @@ def histogramdd(sample, bins=10, range=None, normed=False, weights=None):
def average(a, axis=None, weights=None, returned=False):
- """average(a, axis=None weights=None, returned=False)
-
- Average the array over the given axis. If the axis is None,
+ """Average the array over the given axis. If the axis is None,
average over all dimensions of the array. Equivalent to
a.mean(axis) and to
@@ -452,7 +450,7 @@ def select(condlist, choicelist, default=0):
n2 = len(choicelist)
if n2 != n:
raise ValueError, "list of cases must be same length as list of conditions"
- choicelist.insert(0, default)
+ choicelist = [default] + choicelist
S = 0
pfac = 1
for k in range(1, n+1):
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index f0930ae5b..b22ce1318 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -64,6 +64,20 @@ class test_average(NumpyTestCase):
desired = array([3.,4.,5.])
assert_array_equal(actual, desired)
+class test_select(NumpyTestCase):
+ def check_basic(self):
+ choices = [array([1,2,3]),
+ array([4,5,6]),
+ array([7,8,9])]
+ conditions = [array([0,0,0]),
+ array([0,1,0]),
+ array([0,0,1])]
+ assert_array_equal(select(conditions,choices,default=15),
+ [15,5,9])
+
+ assert_equal(len(choices),3)
+ assert_equal(len(conditions),3)
+
class test_logspace(NumpyTestCase):
def check_basic(self):
y = logspace(0,6)
@@ -431,4 +445,4 @@ def compare_results(res,desired):
assert_array_equal(res[i],desired[i])
if __name__ == "__main__":
- NumpyTest('numpy.lib.function_base').run()
+ NumpyTest().run()