summaryrefslogtreecommitdiff
path: root/numpy/core/shape_base.py
diff options
context:
space:
mode:
authorJamie Townsend <jamiehntownsend@gmail.com>2017-10-12 10:30:21 +0100
committerJamie Townsend <jamiehntownsend@gmail.com>2017-10-12 10:30:21 +0100
commit5a0557ae36dbba08ce3374f56b8d1502913123e4 (patch)
treeefcb77a8f7d06bc18401dcf1ce5f5b242ea55629 /numpy/core/shape_base.py
parent8a83a5fcf8580fbcb3caf3ab3d5971876d9da959 (diff)
downloadnumpy-5a0557ae36dbba08ce3374f56b8d1502913123e4.tar.gz
Pre-calculate max array ndim
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r--numpy/core/shape_base.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index 22a8ac304..c65849752 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -399,8 +399,8 @@ def _block_check_depths_match(arrays, parent_index=[]):
)
)
elif type(arrays) is list and len(arrays) > 0:
- indexes = [_block_check_depths_match(arr, parent_index + [i])
- for i, arr in enumerate(arrays)]
+ indexes, arr_ndims = zip(*[_block_check_depths_match(arr, parent_index + [i])
+ for i, arr in enumerate(arrays)])
first_index = indexes[0]
for i, index in enumerate(indexes):
@@ -413,16 +413,16 @@ def _block_check_depths_match(arrays, parent_index=[]):
format_index(index)
)
)
- return first_index
+ return first_index, max(arr_ndims)
elif type(arrays) is list and len(arrays) == 0:
# We've 'bottomed out' on an empty list
- return parent_index + [None]
+ return parent_index + [None], _nx.ndim(arrays)
else:
# We've 'bottomed out' - arrays is either a scalar or an array
- return parent_index
+ return parent_index, _nx.ndim(arrays)
-def _block(arrays, max_depth):
+def _block(arrays, max_depth, max_ndim):
def atleast_nd(a, ndim):
# Ensures `a` has at least `ndim` dimensions by prepending
# ones to `a.shape` as necessary
@@ -433,13 +433,11 @@ def _block(arrays, max_depth):
if len(arrays) == 0:
raise ValueError('Lists cannot be empty')
arrs = [block_recursion(arr, depth+1) for arr in arrays]
- arr_ndim = max(arr.ndim for arr in arrs)
- arrs = [atleast_nd(a, arr_ndim) for a in arrs]
return _nx.concatenate(arrs, axis=-(max_depth-depth))
else:
# We've 'bottomed out' - arrays is either a scalar or an array
# depth == max_depth
- return atleast_nd(arrays, max_depth)
+ return atleast_nd(arrays, max(max_depth, max_ndim))
return block_recursion(arrays)
@@ -592,5 +590,6 @@ def block(arrays):
"""
- list_ndim = len(_block_check_depths_match(arrays))
- return _block(arrays, list_ndim)
+ bottom_index, arr_ndim = _block_check_depths_match(arrays)
+ list_ndim = len(bottom_index)
+ return _block(arrays, list_ndim, arr_ndim)