diff options
-rw-r--r-- | numpy/lib/shape_base.py | 4 | ||||
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 14 |
2 files changed, 16 insertions, 2 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index d31d8a939..66f534734 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -688,7 +688,7 @@ def array_split(ary, indices_or_sections, axis=0): except AttributeError: Ntotal = len(ary) try: - # handle scalar case. + # handle array case. Nsections = len(indices_or_sections) + 1 div_points = [0] + list(indices_or_sections) + [Ntotal] except TypeError: @@ -700,7 +700,7 @@ def array_split(ary, indices_or_sections, axis=0): section_sizes = ([0] + extras * [Neach_section+1] + (Nsections-extras) * [Neach_section]) - div_points = _nx.array(section_sizes).cumsum() + div_points = _nx.array(section_sizes, dtype=_nx.intp).cumsum() sub_arys = [] sary = _nx.swapaxes(ary, axis, 0) diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index 6d24dd624..6e4cd225d 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -3,6 +3,8 @@ from __future__ import division, absolute_import, print_function import numpy as np import warnings import functools +import sys +import pytest from numpy.lib.shape_base import ( apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit, @@ -14,6 +16,9 @@ from numpy.testing import ( ) +IS_64BIT = sys.maxsize > 2**32 + + def _add_keepdims(func): """ hack in keepdims behavior into a function taking an axis """ @functools.wraps(func) @@ -403,6 +408,15 @@ class TestArraySplit(object): assert_(a.dtype.type is res[-1].dtype.type) # perhaps should check higher dimensions + @pytest.mark.skipif(not IS_64BIT, reason="Needs 64bit platform") + def test_integer_split_2D_rows_greater_max_int32(self): + a = np.broadcast_to([0], (1 << 32, 2)) + res = array_split(a, 4) + chunk = np.broadcast_to([0], (1 << 30, 2)) + tgt = [chunk] * 4 + for i in range(len(tgt)): + assert_equal(res[i].shape, tgt[i].shape) + def test_index_split_simple(self): a = np.arange(10) indices = [1, 5, 7] |