summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorTim Leslie <tim.leslie@gmail.com>2007-01-08 04:01:52 +0000
committerTim Leslie <tim.leslie@gmail.com>2007-01-08 04:01:52 +0000
commit0f1df3fcfcdb4527ff2c332ddf7504b6b60e2813 (patch)
tree780d627d0361672997d85acdde78c01f4cd6ab49 /numpy
parent2e832de49f69d26eb7c8d133e45f9b9d99f7a3a6 (diff)
downloadnumpy-0f1df3fcfcdb4527ff2c332ddf7504b6b60e2813.tar.gz
fix for #407 and add unit test for it
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/shape_base.py4
-rw-r--r--numpy/lib/tests/test_shape_base.py4
2 files changed, 6 insertions, 2 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index cfe768bfd..32c47ede5 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -31,7 +31,7 @@ def apply_along_axis(func1d,axis,arr,*args):
# if res is a number, then we have a smaller output array
if isscalar(res):
outarr = zeros(outshape,asarray(res).dtype)
- outarr[ind] = res
+ outarr[tuple(ind)] = res
Ntot = product(outshape)
k = 1
while k < Ntot:
@@ -44,7 +44,7 @@ def apply_along_axis(func1d,axis,arr,*args):
n -= 1
i.put(indlist,ind)
res = func1d(arr[tuple(i.tolist())],*args)
- outarr[ind] = res
+ outarr[tuple(ind)] = res
k += 1
return outarr
else:
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index b43b08664..416c2644f 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -14,6 +14,10 @@ class test_apply_along_axis(NumpyTestCase):
a = ones((10,101),'d')
assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1]))
+ def check_3d(self):
+ a = arange(27).reshape((3,3,3))
+ assert_array_equal(apply_along_axis(sum,0,a), [[27,30,33],[36,39,42],[45,48,51]])
+
class test_array_split(NumpyTestCase):
def check_integer_0_split(self):
a = arange(10)