diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/shape_base.py | 4 | ||||
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 4 |
2 files changed, 6 insertions, 2 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index cfe768bfd..32c47ede5 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -31,7 +31,7 @@ def apply_along_axis(func1d,axis,arr,*args): # if res is a number, then we have a smaller output array if isscalar(res): outarr = zeros(outshape,asarray(res).dtype) - outarr[ind] = res + outarr[tuple(ind)] = res Ntot = product(outshape) k = 1 while k < Ntot: @@ -44,7 +44,7 @@ def apply_along_axis(func1d,axis,arr,*args): n -= 1 i.put(indlist,ind) res = func1d(arr[tuple(i.tolist())],*args) - outarr[ind] = res + outarr[tuple(ind)] = res k += 1 return outarr else: diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index b43b08664..416c2644f 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -14,6 +14,10 @@ class test_apply_along_axis(NumpyTestCase): a = ones((10,101),'d') assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1])) + def check_3d(self): + a = arange(27).reshape((3,3,3)) + assert_array_equal(apply_along_axis(sum,0,a), [[27,30,33],[36,39,42],[45,48,51]]) + class test_array_split(NumpyTestCase): def check_integer_0_split(self): a = arange(10) |