summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-09-27 12:33:09 -0600
committerCharles Harris <charlesr.harris@gmail.com>2015-09-27 12:33:09 -0600
commitea289ee45dffa85d2fde07685eb3334816d8af7c (patch)
tree92d94ad578e6d2ba26def28837725f6cf0e5690e /numpy/lib
parent41afcc3681d250f231aea9d9f428a9e197a47f6e (diff)
parentf29c387272a9279f82ab04bbbe1bb68040b6d383 (diff)
downloadnumpy-ea289ee45dffa85d2fde07685eb3334816d8af7c.tar.gz
Merge pull request #6371 from seberg/pr-5771
BUG: Make sure warning for array split is always applied
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/shape_base.py2
-rw-r--r--numpy/lib/tests/test_shape_base.py9
2 files changed, 10 insertions, 1 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 26c2aab04..b2beef0a8 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -424,7 +424,7 @@ def array_split(ary, indices_or_sections, axis=0):
# This "kludge" was introduced here to replace arrays shaped (0, 10)
# or similar with an array shaped (0,).
# There seems no need for this, so give a FutureWarning to remove later.
- if sub_arys[-1].size == 0 and sub_arys[-1].ndim != 1:
+ if any(arr.size == 0 and arr.ndim != 1 for arr in sub_arys):
warnings.warn("in the future np.array_split will retain the shape of "
"arrays with a zero size, instead of replacing them by "
"`array([])`, which always has a shape of (0,).",
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 3f2d8d5b4..8ab72b9f9 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -111,6 +111,15 @@ class TestArraySplit(TestCase):
compare_results(res, desired)
assert_(a.dtype.type is res[-1].dtype.type)
+ # Same thing for manual splits:
+ res = assert_warns(FutureWarning, array_split, a, [0, 1, 2], axis=0)
+
+ # After removing the FutureWarning, the last should be zeros((0, 10))
+ desired = [np.array([]), np.array([np.arange(10)]),
+ np.array([np.arange(10)])]
+ compare_results(res, desired)
+ assert_(a.dtype.type is res[-1].dtype.type)
+
def test_integer_split_2D_cols(self):
a = np.array([np.arange(10), np.arange(10)])
res = array_split(a, 3, axis=-1)