diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2014-04-15 21:32:42 -0400 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2014-08-25 13:08:19 -0400 |
commit | 26a02cd702d9ccfc48978dcf81c80225f324bf3b (patch) | |
tree | 27e5a930f4809457c10dbabc100127c73b0aa2f0 /numpy/lib/tests/test_stride_tricks.py | |
parent | 14e4cc3aa8bc6e01d6494860c7de6bf9ec0ab86b (diff) | |
download | numpy-26a02cd702d9ccfc48978dcf81c80225f324bf3b.tar.gz |
ENH: add subok flag to stride_tricks (and thus broadcast_arrays)
Diffstat (limited to 'numpy/lib/tests/test_stride_tricks.py')
-rw-r--r-- | numpy/lib/tests/test_stride_tricks.py | 33 |
1 files changed, 32 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py index cd0973300..8e0f5f0b9 100644 --- a/numpy/lib/tests/test_stride_tricks.py +++ b/numpy/lib/tests/test_stride_tricks.py @@ -3,7 +3,7 @@ from __future__ import division, absolute_import, print_function import numpy as np from numpy.testing import ( run_module_suite, assert_equal, assert_array_equal, - assert_raises + assert_raises, assert_ ) from numpy.lib.stride_tricks import as_strided, broadcast_arrays @@ -234,5 +234,36 @@ def test_as_strided(): assert_array_equal(a_view, expected) +class SimpleSubClass(np.ndarray): + def __new__(cls, *args, **kwargs): + kwargs['subok'] = True + self = np.array(*args, **kwargs).view(cls) + self.info = 'simple' + return self + + def __array_finalize__(self, obj): + self.info = getattr(obj, 'info', '') + ' finalized' + + +def test_subclasses(): + a = SimpleSubClass([1, 2, 3, 4]) + assert_(type(a) is SimpleSubClass) + a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,)) + assert_(type(a_view) is np.ndarray) + a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True) + assert_(a_view.info == 'simple finalized') + assert_(type(a_view) is SimpleSubClass) + b = np.arange(len(a)).reshape(-1, 1) + a_view, b_view = broadcast_arrays(a, b) + assert_(type(a_view) is np.ndarray) + assert_(type(b_view) is np.ndarray) + assert_(a_view.shape == b_view.shape) + a_view, b_view = broadcast_arrays(a, b, subok=True) + assert_(type(a_view) is SimpleSubClass) + assert_(a_view.info == 'simple finalized') + assert_(type(b_view) is np.ndarray) + assert_(a_view.shape == b_view.shape) + + if __name__ == "__main__": run_module_suite() |