diff options
author | Jamie Townsend <jamiehntownsend@gmail.com> | 2017-10-12 13:54:42 +0100 |
---|---|---|
committer | Jamie Townsend <jamiehntownsend@gmail.com> | 2017-10-12 13:54:42 +0100 |
commit | c2b5be5e75bd665e02c97b3d03f559dfde485fbb (patch) | |
tree | 540a7b57acf9e3212840c5cc9e58c65333e802cd /numpy/core/shape_base.py | |
parent | 5a0557ae36dbba08ce3374f56b8d1502913123e4 (diff) | |
download | numpy-c2b5be5e75bd665e02c97b3d03f559dfde485fbb.tar.gz |
Further slight simplifications
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r-- | numpy/core/shape_base.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index c65849752..732ea30ae 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -399,8 +399,8 @@ def _block_check_depths_match(arrays, parent_index=[]): ) ) elif type(arrays) is list and len(arrays) > 0: - indexes, arr_ndims = zip(*[_block_check_depths_match(arr, parent_index + [i]) - for i, arr in enumerate(arrays)]) + indexes, arr_ndims = zip(*(_block_check_depths_match(arr, parent_index + [i]) + for i, arr in enumerate(arrays))) first_index = indexes[0] for i, index in enumerate(indexes): @@ -422,14 +422,14 @@ def _block_check_depths_match(arrays, parent_index=[]): return parent_index, _nx.ndim(arrays) -def _block(arrays, max_depth, max_ndim): +def _block(arrays, max_depth, result_ndim): def atleast_nd(a, ndim): # Ensures `a` has at least `ndim` dimensions by prepending # ones to `a.shape` as necessary return array(a, ndmin=ndim, copy=False, subok=True) def block_recursion(arrays, depth=0): - if type(arrays) is list: + if depth < max_depth: if len(arrays) == 0: raise ValueError('Lists cannot be empty') arrs = [block_recursion(arr, depth+1) for arr in arrays] @@ -437,7 +437,7 @@ def _block(arrays, max_depth, max_ndim): else: # We've 'bottomed out' - arrays is either a scalar or an array # depth == max_depth - return atleast_nd(arrays, max(max_depth, max_ndim)) + return atleast_nd(arrays, result_ndim) return block_recursion(arrays) @@ -592,4 +592,4 @@ def block(arrays): """ bottom_index, arr_ndim = _block_check_depths_match(arrays) list_ndim = len(bottom_index) - return _block(arrays, list_ndim, arr_ndim) + return _block(arrays, list_ndim, max(arr_ndim, list_ndim)) |