diff options
Diffstat (limited to 'numpy/lib/tests/test_stride_tricks.py')
-rw-r--r-- | numpy/lib/tests/test_stride_tricks.py | 33 |
1 files changed, 32 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py index cd0973300..8e0f5f0b9 100644 --- a/numpy/lib/tests/test_stride_tricks.py +++ b/numpy/lib/tests/test_stride_tricks.py @@ -3,7 +3,7 @@ from __future__ import division, absolute_import, print_function import numpy as np from numpy.testing import ( run_module_suite, assert_equal, assert_array_equal, - assert_raises + assert_raises, assert_ ) from numpy.lib.stride_tricks import as_strided, broadcast_arrays @@ -234,5 +234,36 @@ def test_as_strided(): assert_array_equal(a_view, expected) +class SimpleSubClass(np.ndarray): + def __new__(cls, *args, **kwargs): + kwargs['subok'] = True + self = np.array(*args, **kwargs).view(cls) + self.info = 'simple' + return self + + def __array_finalize__(self, obj): + self.info = getattr(obj, 'info', '') + ' finalized' + + +def test_subclasses(): + a = SimpleSubClass([1, 2, 3, 4]) + assert_(type(a) is SimpleSubClass) + a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,)) + assert_(type(a_view) is np.ndarray) + a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True) + assert_(a_view.info == 'simple finalized') + assert_(type(a_view) is SimpleSubClass) + b = np.arange(len(a)).reshape(-1, 1) + a_view, b_view = broadcast_arrays(a, b) + assert_(type(a_view) is np.ndarray) + assert_(type(b_view) is np.ndarray) + assert_(a_view.shape == b_view.shape) + a_view, b_view = broadcast_arrays(a, b, subok=True) + assert_(type(a_view) is SimpleSubClass) + assert_(a_view.info == 'simple finalized') + assert_(type(b_view) is np.ndarray) + assert_(a_view.shape == b_view.shape) + + if __name__ == "__main__": run_module_suite() |