summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_stride_tricks.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/tests/test_stride_tricks.py')
-rw-r--r--numpy/lib/tests/test_stride_tricks.py33
1 files changed, 32 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py
index cd0973300..8e0f5f0b9 100644
--- a/numpy/lib/tests/test_stride_tricks.py
+++ b/numpy/lib/tests/test_stride_tricks.py
@@ -3,7 +3,7 @@ from __future__ import division, absolute_import, print_function
import numpy as np
from numpy.testing import (
run_module_suite, assert_equal, assert_array_equal,
- assert_raises
+ assert_raises, assert_
)
from numpy.lib.stride_tricks import as_strided, broadcast_arrays
@@ -234,5 +234,36 @@ def test_as_strided():
assert_array_equal(a_view, expected)
+class SimpleSubClass(np.ndarray):
+ def __new__(cls, *args, **kwargs):
+ kwargs['subok'] = True
+ self = np.array(*args, **kwargs).view(cls)
+ self.info = 'simple'
+ return self
+
+ def __array_finalize__(self, obj):
+ self.info = getattr(obj, 'info', '') + ' finalized'
+
+
+def test_subclasses():
+ a = SimpleSubClass([1, 2, 3, 4])
+ assert_(type(a) is SimpleSubClass)
+ a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,))
+ assert_(type(a_view) is np.ndarray)
+ a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
+ assert_(a_view.info == 'simple finalized')
+ assert_(type(a_view) is SimpleSubClass)
+ b = np.arange(len(a)).reshape(-1, 1)
+ a_view, b_view = broadcast_arrays(a, b)
+ assert_(type(a_view) is np.ndarray)
+ assert_(type(b_view) is np.ndarray)
+ assert_(a_view.shape == b_view.shape)
+ a_view, b_view = broadcast_arrays(a, b, subok=True)
+ assert_(type(a_view) is SimpleSubClass)
+ assert_(a_view.info == 'simple finalized')
+ assert_(type(b_view) is np.ndarray)
+ assert_(a_view.shape == b_view.shape)
+
+
if __name__ == "__main__":
run_module_suite()