diff options
| author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2021-06-26 18:01:48 -0400 |
|---|---|---|
| committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2021-06-26 18:59:33 -0400 |
| commit | 3c892cd4f0809f5a6e736ab234a6a44fc326e29a (patch) | |
| tree | 7912532b552db5a915fef480582efe2060e4f9f4 /numpy/lib/tests | |
| parent | cbec2c8054ea6150490b9e72eb051848b79344d1 (diff) | |
| download | numpy-3c892cd4f0809f5a6e736ab234a6a44fc326e29a.tar.gz | |
API: Ensure np.vectorize outputs can be subclasses.
As is, this is true for the ufunc case, but not for the gufunc case,
even if the underlying function does produce a subclass. Given the
care taken to ensure inputs are kept as subclasses, this is almost
certainly an oversight, which is here corrected.
Diffstat (limited to 'numpy/lib/tests')
| -rw-r--r-- | numpy/lib/tests/test_function_base.py | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index a4f49a78b..e1b615223 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -1665,6 +1665,26 @@ class TestVectorize: with assert_raises_regex(ValueError, 'new output dimensions'): f(x) + def test_subclasses(self): + class subclass(np.ndarray): + pass + + m = np.array([[1., 0., 0.], + [0., 0., 1.], + [0., 1., 0.]]).view(subclass) + v = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]).view(subclass) + # generalized (gufunc) + matvec = np.vectorize(np.matmul, signature='(m,m),(m)->(m)') + r = matvec(m, v) + assert_equal(type(r), subclass) + assert_equal(r, [[1., 3., 2.], [4., 6., 5.], [7., 9., 8.]]) + + # element-wise (ufunc) + mult = np.vectorize(lambda x, y: x*y) + r = mult(m, v) + assert_equal(type(r), subclass) + assert_equal(r, m * v) + class TestLeaks: class A: @@ -1798,7 +1818,7 @@ class TestUnwrap: assert_array_equal(unwrap([1, 1 + 2 * np.pi]), [1, 1]) # check that unwrap maintains continuity assert_(np.all(diff(unwrap(rand(10) * 100)) < np.pi)) - + def test_period(self): # check that unwrap removes jumps greater that 255 assert_array_equal(unwrap([1, 1 + 256], period=255), [1, 2]) |
