diff options
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r-- | numpy/core/shape_base.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index c65849752..732ea30ae 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -399,8 +399,8 @@ def _block_check_depths_match(arrays, parent_index=[]): ) ) elif type(arrays) is list and len(arrays) > 0: - indexes, arr_ndims = zip(*[_block_check_depths_match(arr, parent_index + [i]) - for i, arr in enumerate(arrays)]) + indexes, arr_ndims = zip(*(_block_check_depths_match(arr, parent_index + [i]) + for i, arr in enumerate(arrays))) first_index = indexes[0] for i, index in enumerate(indexes): @@ -422,14 +422,14 @@ def _block_check_depths_match(arrays, parent_index=[]): return parent_index, _nx.ndim(arrays) -def _block(arrays, max_depth, max_ndim): +def _block(arrays, max_depth, result_ndim): def atleast_nd(a, ndim): # Ensures `a` has at least `ndim` dimensions by prepending # ones to `a.shape` as necessary return array(a, ndmin=ndim, copy=False, subok=True) def block_recursion(arrays, depth=0): - if type(arrays) is list: + if depth < max_depth: if len(arrays) == 0: raise ValueError('Lists cannot be empty') arrs = [block_recursion(arr, depth+1) for arr in arrays] @@ -437,7 +437,7 @@ def _block(arrays, max_depth, max_ndim): else: # We've 'bottomed out' - arrays is either a scalar or an array # depth == max_depth - return atleast_nd(arrays, max(max_depth, max_ndim)) + return atleast_nd(arrays, result_ndim) return block_recursion(arrays) @@ -592,4 +592,4 @@ def block(arrays): """ bottom_index, arr_ndim = _block_check_depths_match(arrays) list_ndim = len(bottom_index) - return _block(arrays, list_ndim, arr_ndim) + return _block(arrays, list_ndim, max(arr_ndim, list_ndim)) |