diff options
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r-- | numpy/core/shape_base.py | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 22a8ac304..c65849752 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -399,8 +399,8 @@ def _block_check_depths_match(arrays, parent_index=[]): ) ) elif type(arrays) is list and len(arrays) > 0: - indexes = [_block_check_depths_match(arr, parent_index + [i]) - for i, arr in enumerate(arrays)] + indexes, arr_ndims = zip(*[_block_check_depths_match(arr, parent_index + [i]) + for i, arr in enumerate(arrays)]) first_index = indexes[0] for i, index in enumerate(indexes): @@ -413,16 +413,16 @@ def _block_check_depths_match(arrays, parent_index=[]): format_index(index) ) ) - return first_index + return first_index, max(arr_ndims) elif type(arrays) is list and len(arrays) == 0: # We've 'bottomed out' on an empty list - return parent_index + [None] + return parent_index + [None], _nx.ndim(arrays) else: # We've 'bottomed out' - arrays is either a scalar or an array - return parent_index + return parent_index, _nx.ndim(arrays) -def _block(arrays, max_depth): +def _block(arrays, max_depth, max_ndim): def atleast_nd(a, ndim): # Ensures `a` has at least `ndim` dimensions by prepending # ones to `a.shape` as necessary @@ -433,13 +433,11 @@ def _block(arrays, max_depth): if len(arrays) == 0: raise ValueError('Lists cannot be empty') arrs = [block_recursion(arr, depth+1) for arr in arrays] - arr_ndim = max(arr.ndim for arr in arrs) - arrs = [atleast_nd(a, arr_ndim) for a in arrs] return _nx.concatenate(arrs, axis=-(max_depth-depth)) else: # We've 'bottomed out' - arrays is either a scalar or an array # depth == max_depth - return atleast_nd(arrays, max_depth) + return atleast_nd(arrays, max(max_depth, max_ndim)) return block_recursion(arrays) @@ -592,5 +590,6 @@ def block(arrays): """ - list_ndim = len(_block_check_depths_match(arrays)) - return _block(arrays, list_ndim) + bottom_index, arr_ndim = _block_check_depths_match(arrays) + list_ndim = len(bottom_index) + return _block(arrays, list_ndim, arr_ndim) |