diff options
author | Jamie Townsend <jamiehntownsend@gmail.com> | 2017-10-12 10:30:21 +0100 |
---|---|---|
committer | Jamie Townsend <jamiehntownsend@gmail.com> | 2017-10-12 10:30:21 +0100 |
commit | 5a0557ae36dbba08ce3374f56b8d1502913123e4 (patch) | |
tree | efcb77a8f7d06bc18401dcf1ce5f5b242ea55629 /numpy/core/shape_base.py | |
parent | 8a83a5fcf8580fbcb3caf3ab3d5971876d9da959 (diff) | |
download | numpy-5a0557ae36dbba08ce3374f56b8d1502913123e4.tar.gz |
Pre-calculate max array ndim
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r-- | numpy/core/shape_base.py | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 22a8ac304..c65849752 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -399,8 +399,8 @@ def _block_check_depths_match(arrays, parent_index=[]): ) ) elif type(arrays) is list and len(arrays) > 0: - indexes = [_block_check_depths_match(arr, parent_index + [i]) - for i, arr in enumerate(arrays)] + indexes, arr_ndims = zip(*[_block_check_depths_match(arr, parent_index + [i]) + for i, arr in enumerate(arrays)]) first_index = indexes[0] for i, index in enumerate(indexes): @@ -413,16 +413,16 @@ def _block_check_depths_match(arrays, parent_index=[]): format_index(index) ) ) - return first_index + return first_index, max(arr_ndims) elif type(arrays) is list and len(arrays) == 0: # We've 'bottomed out' on an empty list - return parent_index + [None] + return parent_index + [None], _nx.ndim(arrays) else: # We've 'bottomed out' - arrays is either a scalar or an array - return parent_index + return parent_index, _nx.ndim(arrays) -def _block(arrays, max_depth): +def _block(arrays, max_depth, max_ndim): def atleast_nd(a, ndim): # Ensures `a` has at least `ndim` dimensions by prepending # ones to `a.shape` as necessary @@ -433,13 +433,11 @@ def _block(arrays, max_depth): if len(arrays) == 0: raise ValueError('Lists cannot be empty') arrs = [block_recursion(arr, depth+1) for arr in arrays] - arr_ndim = max(arr.ndim for arr in arrs) - arrs = [atleast_nd(a, arr_ndim) for a in arrs] return _nx.concatenate(arrs, axis=-(max_depth-depth)) else: # We've 'bottomed out' - arrays is either a scalar or an array # depth == max_depth - return atleast_nd(arrays, max_depth) + return atleast_nd(arrays, max(max_depth, max_ndim)) return block_recursion(arrays) @@ -592,5 +590,6 @@ def block(arrays): """ - list_ndim = len(_block_check_depths_match(arrays)) - return _block(arrays, list_ndim) + bottom_index, arr_ndim = _block_check_depths_match(arrays) + list_ndim = len(bottom_index) + return _block(arrays, list_ndim, arr_ndim) |