summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_stride_tricks.py
diff options
context:
space:
mode:
authorseberg <sebastian@sipsolutions.net>2014-09-26 18:42:33 +0200
committerseberg <sebastian@sipsolutions.net>2014-09-26 18:42:33 +0200
commit22ae19366cf04a80d3e1ecc039d775656718a668 (patch)
tree4a99a7e7497144f5bd5408e3374fca4d277e3230 /numpy/lib/tests/test_stride_tricks.py
parentbefb246c6d011abc1b4b42ecd8b2e3731b81ce04 (diff)
parente8590311a7b312711c7a4f40c1a15496e34d0ee6 (diff)
downloadnumpy-22ae19366cf04a80d3e1ecc039d775656718a668.tar.gz
Merge pull request #4622 from mhvk/lib/stride_tricks/subclasses
ENH: add subok flag to stride_tricks (and thus broadcast_arrays)
Diffstat (limited to 'numpy/lib/tests/test_stride_tricks.py')
-rw-r--r--numpy/lib/tests/test_stride_tricks.py46
1 files changed, 45 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py
index cd0973300..bc7e30ca4 100644
--- a/numpy/lib/tests/test_stride_tricks.py
+++ b/numpy/lib/tests/test_stride_tricks.py
@@ -3,7 +3,7 @@ from __future__ import division, absolute_import, print_function
import numpy as np
from numpy.testing import (
run_module_suite, assert_equal, assert_array_equal,
- assert_raises
+ assert_raises, assert_
)
from numpy.lib.stride_tricks import as_strided, broadcast_arrays
@@ -234,5 +234,49 @@ def test_as_strided():
assert_array_equal(a_view, expected)
+class VerySimpleSubClass(np.ndarray):
+ def __new__(cls, *args, **kwargs):
+ kwargs['subok'] = True
+ return np.array(*args, **kwargs).view(cls)
+
+
+class SimpleSubClass(VerySimpleSubClass):
+ def __new__(cls, *args, **kwargs):
+ kwargs['subok'] = True
+ self = np.array(*args, **kwargs).view(cls)
+ self.info = 'simple'
+ return self
+
+ def __array_finalize__(self, obj):
+ self.info = getattr(obj, 'info', '') + ' finalized'
+
+
+def test_subclasses():
+ # test that subclass is preserved only if subok=True
+ a = VerySimpleSubClass([1, 2, 3, 4])
+ assert_(type(a) is VerySimpleSubClass)
+ a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,))
+ assert_(type(a_view) is np.ndarray)
+ a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
+ assert_(type(a_view) is VerySimpleSubClass)
+ # test that if a subclass has __array_finalize__, it is used
+ a = SimpleSubClass([1, 2, 3, 4])
+ a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
+ assert_(type(a_view) is SimpleSubClass)
+ assert_(a_view.info == 'simple finalized')
+
+ # similar tests for broadcast_arrays
+ b = np.arange(len(a)).reshape(-1, 1)
+ a_view, b_view = broadcast_arrays(a, b)
+ assert_(type(a_view) is np.ndarray)
+ assert_(type(b_view) is np.ndarray)
+ assert_(a_view.shape == b_view.shape)
+ a_view, b_view = broadcast_arrays(a, b, subok=True)
+ assert_(type(a_view) is SimpleSubClass)
+ assert_(a_view.info == 'simple finalized')
+ assert_(type(b_view) is np.ndarray)
+ assert_(a_view.shape == b_view.shape)
+
+
if __name__ == "__main__":
run_module_suite()