diff options
author | Ben Rowland <bennyrowland@mac.com> | 2016-10-11 12:08:47 -0400 |
---|---|---|
committer | Stephan Hoyer <shoyer@gmail.com> | 2016-10-11 12:08:47 -0400 |
commit | 84b11f57335998e1c2c32e0b1b529ca94895253d (patch) | |
tree | 0f6ef23130f6f5d3cf87ff4eadd07c16e2bfb6fa /numpy/lib | |
parent | dbb70947b95b92abf3e48ef8be7bd3ab79489892 (diff) | |
download | numpy-84b11f57335998e1c2c32e0b1b529ca94895253d.tar.gz |
ENH: allow numpy.apply_along_axis() to work with ndarray subclasses (#7918)
This commit modifies the numpy.apply_along_axis() function so that if
it is called with an ndarray subclass, the internal func1d calls
receive subclass instances and the overall function returns an instance
of the subclass. There are two new tests for these two behaviours.
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/shape_base.py | 10 | ||||
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 31 |
2 files changed, 38 insertions, 3 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 7533f798f..e580690d1 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -74,7 +74,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): [2, 5, 6]]) """ - arr = asarray(arr) + arr = asanyarray(arr) nd = arr.ndim if axis < 0: axis += nd @@ -109,11 +109,13 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): k += 1 return outarr else: + res = asanyarray(res) Ntot = product(outshape) holdshape = outshape outshape = list(arr.shape) - outshape[axis] = len(res) - outarr = zeros(outshape, asarray(res).dtype) + outshape[axis] = res.size + outarr = zeros(outshape, res.dtype) + outarr = res.__array_wrap__(outarr) outarr[tuple(i.tolist())] = res k = 1 while k < Ntot: @@ -128,6 +130,8 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): res = func1d(arr[tuple(i.tolist())], *args, **kwargs) outarr[tuple(i.tolist())] = res k += 1 + if res.shape == (): + outarr = outarr.squeeze(axis) return outarr diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index 0177a5729..2eb4a809d 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -27,6 +27,37 @@ class TestApplyAlongAxis(TestCase): assert_array_equal(apply_along_axis(np.sum, 0, a), [[27, 30, 33], [36, 39, 42], [45, 48, 51]]) + def test_preserve_subclass(self): + def double(row): + return row * 2 + m = np.matrix([[0, 1], [2, 3]]) + result = apply_along_axis(double, 0, m) + assert isinstance(result, np.matrix) + assert_array_equal( + result, np.matrix([[0, 2], [4, 6]]) + ) + + def test_subclass(self): + class MinimalSubclass(np.ndarray): + data = 1 + + def minimal_function(array): + return array.data + + a = np.zeros((6, 3)).view(MinimalSubclass) + + assert_array_equal( + apply_along_axis(minimal_function, 0, a), np.array([1, 1, 1]) + ) + + def test_scalar_array(self): + class MinimalSubclass(np.ndarray): + pass + a = np.ones((6, 3)).view(MinimalSubclass) + res = apply_along_axis(np.sum, 0, a) + assert isinstance(res, MinimalSubclass) + assert_array_equal(res, np.array([6, 6, 6]).view(MinimalSubclass)) + class TestApplyOverAxes(TestCase): def test_simple(self): |