diff options
Diffstat (limited to 'numpy/lib/tests')
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 94 |
1 files changed, 85 insertions, 9 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index a716d3b38..530472a9b 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -28,14 +28,20 @@ class TestApplyAlongAxis(TestCase): [[27, 30, 33], [36, 39, 42], [45, 48, 51]]) def test_preserve_subclass(self): + # this test is particularly malicious because matrix + # refuses to become 1d def double(row): return row * 2 m = np.matrix([[0, 1], [2, 3]]) + expected = np.matrix([[0, 2], [4, 6]]) + result = apply_along_axis(double, 0, m) - assert isinstance(result, np.matrix) - assert_array_equal( - result, np.matrix([[0, 2], [4, 6]]) - ) + assert_(isinstance(result, np.matrix)) + assert_array_equal(result, expected) + + result = apply_along_axis(double, 1, m) + assert_(isinstance(result, np.matrix)) + assert_array_equal(result, expected) def test_subclass(self): class MinimalSubclass(np.ndarray): @@ -50,13 +56,83 @@ class TestApplyAlongAxis(TestCase): apply_along_axis(minimal_function, 0, a), np.array([1, 1, 1]) ) - def test_scalar_array(self): + def test_scalar_array(self, cls=np.ndarray): + a = np.ones((6, 3)).view(cls) + res = apply_along_axis(np.sum, 0, a) + assert_(isinstance(res, cls)) + assert_array_equal(res, np.array([6, 6, 6]).view(cls)) + + def test_0d_array(self, cls=np.ndarray): + def sum_to_0d(x): + """ Sum x, returning a 0d array of the same class """ + assert_equal(x.ndim, 1) + return np.squeeze(np.sum(x, keepdims=True)) + a = np.ones((6, 3)).view(cls) + res = apply_along_axis(sum_to_0d, 0, a) + assert_(isinstance(res, cls)) + assert_array_equal(res, np.array([6, 6, 6]).view(cls)) + + res = apply_along_axis(sum_to_0d, 1, a) + assert_(isinstance(res, cls)) + assert_array_equal(res, np.array([3, 3, 3, 3, 3, 3]).view(cls)) + + def test_axis_insertion(self, cls=np.ndarray): + def f1to2(x): + """produces an assymmetric non-square matrix from x""" + assert_equal(x.ndim, 1) + return (x[::-1] * x[1:,None]).view(cls) + + a2d = np.arange(6*3).reshape((6, 3)) + + # 2d insertion along first axis + actual = apply_along_axis(f1to2, 0, a2d) + expected = np.stack([ + f1to2(a2d[:,i]) for i in range(a2d.shape[1]) + ], axis=-1).view(cls) + assert_equal(type(actual), type(expected)) + assert_equal(actual, expected) + + # 2d insertion along last axis + actual = apply_along_axis(f1to2, 1, a2d) + expected = np.stack([ + f1to2(a2d[i,:]) for i in range(a2d.shape[0]) + ], axis=0).view(cls) + assert_equal(type(actual), type(expected)) + assert_equal(actual, expected) + + # 3d insertion along middle axis + a3d = np.arange(6*5*3).reshape((6, 5, 3)) + + actual = apply_along_axis(f1to2, 1, a3d) + expected = np.stack([ + np.stack([ + f1to2(a3d[i,:,j]) for i in range(a3d.shape[0]) + ], axis=0) + for j in range(a3d.shape[2]) + ], axis=-1).view(cls) + assert_equal(type(actual), type(expected)) + assert_equal(actual, expected) + + def test_subclass_preservation(self): class MinimalSubclass(np.ndarray): pass - a = np.ones((6, 3)).view(MinimalSubclass) - res = apply_along_axis(np.sum, 0, a) - assert isinstance(res, MinimalSubclass) - assert_array_equal(res, np.array([6, 6, 6]).view(MinimalSubclass)) + self.test_scalar_array(MinimalSubclass) + self.test_0d_array(MinimalSubclass) + self.test_axis_insertion(MinimalSubclass) + + def test_axis_insertion_ma(self): + def f1to2(x): + """produces an assymmetric non-square matrix from x""" + assert_equal(x.ndim, 1) + res = x[::-1] * x[1:,None] + return np.ma.masked_where(res%5==0, res) + a = np.arange(6*3).reshape((6, 3)) + res = apply_along_axis(f1to2, 0, a) + assert_(isinstance(res, np.ma.masked_array)) + assert_equal(res.ndim, 3) + assert_array_equal(res[:,:,0].mask, f1to2(a[:,0]).mask) + assert_array_equal(res[:,:,1].mask, f1to2(a[:,1]).mask) + assert_array_equal(res[:,:,2].mask, f1to2(a[:,2]).mask) def test_tuple_func1d(self): def sample_1d(x): |