summaryrefslogtreecommitdiff
path: root/numpy/core/shape_base.py
diff options
context:
space:
mode:
authorJamie Townsend <jamiehntownsend@gmail.com>2017-10-12 13:54:42 +0100
committerJamie Townsend <jamiehntownsend@gmail.com>2017-10-12 13:54:42 +0100
commitc2b5be5e75bd665e02c97b3d03f559dfde485fbb (patch)
tree540a7b57acf9e3212840c5cc9e58c65333e802cd /numpy/core/shape_base.py
parent5a0557ae36dbba08ce3374f56b8d1502913123e4 (diff)
downloadnumpy-c2b5be5e75bd665e02c97b3d03f559dfde485fbb.tar.gz
Further slight simplifications
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r--numpy/core/shape_base.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index c65849752..732ea30ae 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -399,8 +399,8 @@ def _block_check_depths_match(arrays, parent_index=[]):
)
)
elif type(arrays) is list and len(arrays) > 0:
- indexes, arr_ndims = zip(*[_block_check_depths_match(arr, parent_index + [i])
- for i, arr in enumerate(arrays)])
+ indexes, arr_ndims = zip(*(_block_check_depths_match(arr, parent_index + [i])
+ for i, arr in enumerate(arrays)))
first_index = indexes[0]
for i, index in enumerate(indexes):
@@ -422,14 +422,14 @@ def _block_check_depths_match(arrays, parent_index=[]):
return parent_index, _nx.ndim(arrays)
-def _block(arrays, max_depth, max_ndim):
+def _block(arrays, max_depth, result_ndim):
def atleast_nd(a, ndim):
# Ensures `a` has at least `ndim` dimensions by prepending
# ones to `a.shape` as necessary
return array(a, ndmin=ndim, copy=False, subok=True)
def block_recursion(arrays, depth=0):
- if type(arrays) is list:
+ if depth < max_depth:
if len(arrays) == 0:
raise ValueError('Lists cannot be empty')
arrs = [block_recursion(arr, depth+1) for arr in arrays]
@@ -437,7 +437,7 @@ def _block(arrays, max_depth, max_ndim):
else:
# We've 'bottomed out' - arrays is either a scalar or an array
# depth == max_depth
- return atleast_nd(arrays, max(max_depth, max_ndim))
+ return atleast_nd(arrays, result_ndim)
return block_recursion(arrays)
@@ -592,4 +592,4 @@ def block(arrays):
"""
bottom_index, arr_ndim = _block_check_depths_match(arrays)
list_ndim = len(bottom_index)
- return _block(arrays, list_ndim, arr_ndim)
+ return _block(arrays, list_ndim, max(arr_ndim, list_ndim))