diff options
Diffstat (limited to 'numpy/lib/tests')
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index a4f49a78b..e1b615223 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -1665,6 +1665,26 @@ class TestVectorize: with assert_raises_regex(ValueError, 'new output dimensions'): f(x) + def test_subclasses(self): + class subclass(np.ndarray): + pass + + m = np.array([[1., 0., 0.], + [0., 0., 1.], + [0., 1., 0.]]).view(subclass) + v = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]).view(subclass) + # generalized (gufunc) + matvec = np.vectorize(np.matmul, signature='(m,m),(m)->(m)') + r = matvec(m, v) + assert_equal(type(r), subclass) + assert_equal(r, [[1., 3., 2.], [4., 6., 5.], [7., 9., 8.]]) + + # element-wise (ufunc) + mult = np.vectorize(lambda x, y: x*y) + r = mult(m, v) + assert_equal(type(r), subclass) + assert_equal(r, m * v) + class TestLeaks: class A: @@ -1798,7 +1818,7 @@ class TestUnwrap: assert_array_equal(unwrap([1, 1 + 2 * np.pi]), [1, 1]) # check that unwrap maintains continuity assert_(np.all(diff(unwrap(rand(10) * 100)) < np.pi)) - + def test_period(self): # check that unwrap removes jumps greater that 255 assert_array_equal(unwrap([1, 1 + 256], period=255), [1, 2]) |