summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_stride_tricks.py
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2014-04-15 21:32:42 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2014-08-25 13:08:19 -0400
commit26a02cd702d9ccfc48978dcf81c80225f324bf3b (patch)
tree27e5a930f4809457c10dbabc100127c73b0aa2f0 /numpy/lib/tests/test_stride_tricks.py
parent14e4cc3aa8bc6e01d6494860c7de6bf9ec0ab86b (diff)
downloadnumpy-26a02cd702d9ccfc48978dcf81c80225f324bf3b.tar.gz
ENH: add subok flag to stride_tricks (and thus broadcast_arrays)
Diffstat (limited to 'numpy/lib/tests/test_stride_tricks.py')
-rw-r--r--numpy/lib/tests/test_stride_tricks.py33
1 files changed, 32 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py
index cd0973300..8e0f5f0b9 100644
--- a/numpy/lib/tests/test_stride_tricks.py
+++ b/numpy/lib/tests/test_stride_tricks.py
@@ -3,7 +3,7 @@ from __future__ import division, absolute_import, print_function
import numpy as np
from numpy.testing import (
run_module_suite, assert_equal, assert_array_equal,
- assert_raises
+ assert_raises, assert_
)
from numpy.lib.stride_tricks import as_strided, broadcast_arrays
@@ -234,5 +234,36 @@ def test_as_strided():
assert_array_equal(a_view, expected)
+class SimpleSubClass(np.ndarray):
+ def __new__(cls, *args, **kwargs):
+ kwargs['subok'] = True
+ self = np.array(*args, **kwargs).view(cls)
+ self.info = 'simple'
+ return self
+
+ def __array_finalize__(self, obj):
+ self.info = getattr(obj, 'info', '') + ' finalized'
+
+
+def test_subclasses():
+ a = SimpleSubClass([1, 2, 3, 4])
+ assert_(type(a) is SimpleSubClass)
+ a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,))
+ assert_(type(a_view) is np.ndarray)
+ a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
+ assert_(a_view.info == 'simple finalized')
+ assert_(type(a_view) is SimpleSubClass)
+ b = np.arange(len(a)).reshape(-1, 1)
+ a_view, b_view = broadcast_arrays(a, b)
+ assert_(type(a_view) is np.ndarray)
+ assert_(type(b_view) is np.ndarray)
+ assert_(a_view.shape == b_view.shape)
+ a_view, b_view = broadcast_arrays(a, b, subok=True)
+ assert_(type(a_view) is SimpleSubClass)
+ assert_(a_view.info == 'simple finalized')
+ assert_(type(b_view) is np.ndarray)
+ assert_(a_view.shape == b_view.shape)
+
+
if __name__ == "__main__":
run_module_suite()