diff options
author | Tim Leslie <tim.leslie@gmail.com> | 2007-01-08 04:01:52 +0000 |
---|---|---|
committer | Tim Leslie <tim.leslie@gmail.com> | 2007-01-08 04:01:52 +0000 |
commit | 0f1df3fcfcdb4527ff2c332ddf7504b6b60e2813 (patch) | |
tree | 780d627d0361672997d85acdde78c01f4cd6ab49 /numpy | |
parent | 2e832de49f69d26eb7c8d133e45f9b9d99f7a3a6 (diff) | |
download | numpy-0f1df3fcfcdb4527ff2c332ddf7504b6b60e2813.tar.gz |
fix for #407 and add unit test for it
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/shape_base.py | 4 | ||||
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 4 |
2 files changed, 6 insertions, 2 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index cfe768bfd..32c47ede5 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -31,7 +31,7 @@ def apply_along_axis(func1d,axis,arr,*args): # if res is a number, then we have a smaller output array if isscalar(res): outarr = zeros(outshape,asarray(res).dtype) - outarr[ind] = res + outarr[tuple(ind)] = res Ntot = product(outshape) k = 1 while k < Ntot: @@ -44,7 +44,7 @@ def apply_along_axis(func1d,axis,arr,*args): n -= 1 i.put(indlist,ind) res = func1d(arr[tuple(i.tolist())],*args) - outarr[ind] = res + outarr[tuple(ind)] = res k += 1 return outarr else: diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index b43b08664..416c2644f 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -14,6 +14,10 @@ class test_apply_along_axis(NumpyTestCase): a = ones((10,101),'d') assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1])) + def check_3d(self): + a = arange(27).reshape((3,3,3)) + assert_array_equal(apply_along_axis(sum,0,a), [[27,30,33],[36,39,42],[45,48,51]]) + class test_array_split(NumpyTestCase): def check_integer_0_split(self): a = arange(10) |