summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorBen Rowland <bennyrowland@mac.com>2016-10-11 12:08:47 -0400
committerStephan Hoyer <shoyer@gmail.com>2016-10-11 12:08:47 -0400
commit84b11f57335998e1c2c32e0b1b529ca94895253d (patch)
tree0f6ef23130f6f5d3cf87ff4eadd07c16e2bfb6fa /numpy/lib
parentdbb70947b95b92abf3e48ef8be7bd3ab79489892 (diff)
downloadnumpy-84b11f57335998e1c2c32e0b1b529ca94895253d.tar.gz
ENH: allow numpy.apply_along_axis() to work with ndarray subclasses (#7918)
This commit modifies the numpy.apply_along_axis() function so that if it is called with an ndarray subclass, the internal func1d calls receive subclass instances and the overall function returns an instance of the subclass. There are two new tests for these two behaviours.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/shape_base.py10
-rw-r--r--numpy/lib/tests/test_shape_base.py31
2 files changed, 38 insertions, 3 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 7533f798f..e580690d1 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -74,7 +74,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
[2, 5, 6]])
"""
- arr = asarray(arr)
+ arr = asanyarray(arr)
nd = arr.ndim
if axis < 0:
axis += nd
@@ -109,11 +109,13 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
k += 1
return outarr
else:
+ res = asanyarray(res)
Ntot = product(outshape)
holdshape = outshape
outshape = list(arr.shape)
- outshape[axis] = len(res)
- outarr = zeros(outshape, asarray(res).dtype)
+ outshape[axis] = res.size
+ outarr = zeros(outshape, res.dtype)
+ outarr = res.__array_wrap__(outarr)
outarr[tuple(i.tolist())] = res
k = 1
while k < Ntot:
@@ -128,6 +130,8 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
res = func1d(arr[tuple(i.tolist())], *args, **kwargs)
outarr[tuple(i.tolist())] = res
k += 1
+ if res.shape == ():
+ outarr = outarr.squeeze(axis)
return outarr
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 0177a5729..2eb4a809d 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -27,6 +27,37 @@ class TestApplyAlongAxis(TestCase):
assert_array_equal(apply_along_axis(np.sum, 0, a),
[[27, 30, 33], [36, 39, 42], [45, 48, 51]])
+ def test_preserve_subclass(self):
+ def double(row):
+ return row * 2
+ m = np.matrix([[0, 1], [2, 3]])
+ result = apply_along_axis(double, 0, m)
+ assert isinstance(result, np.matrix)
+ assert_array_equal(
+ result, np.matrix([[0, 2], [4, 6]])
+ )
+
+ def test_subclass(self):
+ class MinimalSubclass(np.ndarray):
+ data = 1
+
+ def minimal_function(array):
+ return array.data
+
+ a = np.zeros((6, 3)).view(MinimalSubclass)
+
+ assert_array_equal(
+ apply_along_axis(minimal_function, 0, a), np.array([1, 1, 1])
+ )
+
+ def test_scalar_array(self):
+ class MinimalSubclass(np.ndarray):
+ pass
+ a = np.ones((6, 3)).view(MinimalSubclass)
+ res = apply_along_axis(np.sum, 0, a)
+ assert isinstance(res, MinimalSubclass)
+ assert_array_equal(res, np.array([6, 6, 6]).view(MinimalSubclass))
+
class TestApplyOverAxes(TestCase):
def test_simple(self):