summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2018-10-25 12:44:26 -0700
committerStephan Hoyer <shoyer@google.com>2018-10-25 12:44:26 -0700
commit4701c31e7e78b3498ed3b959a135f51c15f552d7 (patch)
treeca918e9b9ca3eba14fce72aa362c5b9e699bdf88
parentdfab760b4a328d9fa29cef123e0fe8e2926b0c8c (diff)
downloadnumpy-4701c31e7e78b3498ed3b959a135f51c15f552d7.tar.gz
TST: tests for _block_dispatcher
-rw-r--r--numpy/core/tests/test_shape_base.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py
index df819b73f..37112795a 100644
--- a/numpy/core/tests/test_shape_base.py
+++ b/numpy/core/tests/test_shape_base.py
@@ -6,6 +6,7 @@ from numpy.core import (
array, arange, atleast_1d, atleast_2d, atleast_3d, block, vstack, hstack,
newaxis, concatenate, stack
)
+from numpy.core.shape_base import _block_dispatcher
from numpy.testing import (
assert_, assert_raises, assert_array_equal, assert_equal,
assert_raises_regex, assert_almost_equal
@@ -592,3 +593,16 @@ class TestBlock(object):
[3., 3., 3.]]])
assert_equal(result, expected)
+
+ def test_block_dispatcher(self):
+ class MyArray(object):
+ __array_function__ = None
+ a = MyArray()
+ b = MyArray()
+ c = MyArray()
+ assert_equal(list(_block_dispatcher(a)), [a])
+ assert_equal(list(_block_dispatcher([a])), [a])
+ assert_equal(list(_block_dispatcher([a, b])), [a, b])
+ assert_equal(list(_block_dispatcher([[a], [b, [c]]])), [a, b, c])
+ # don't recurse into non-lists
+ assert_equal(list(_block_dispatcher((a, b))), [(a, b)])