summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornjsmith <njs@pobox.com>2012-09-24 09:32:50 -0700
committernjsmith <njs@pobox.com>2012-09-24 09:32:50 -0700
commit6a847ef7361f4ab4910e69f20cb740d4a2eca4b6 (patch)
treed174d0d23d5941d64ad960055a8c3433f3dccbd3
parentc8010d0ebca7e0d84c653a9440faf12d26feed9e (diff)
parent69afd27b870cd85f06c4409fcffd0734ddb2fe76 (diff)
downloadnumpy-6a847ef7361f4ab4910e69f20cb740d4a2eca4b6.tar.gz
Merge pull request #440 from matthew-brett/crazy-axis-concat-warning
BUG: allow any axis for np.concatenate for 1D
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c10
-rw-r--r--numpy/core/tests/test_shape_base.py73
2 files changed, 81 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 29beb841c..a7b2ba425 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -337,6 +337,16 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis)
if (axis < 0) {
axis += ndim;
}
+
+ if (ndim == 1 & axis != 0) {
+ char msg[] = "axis != 0 for ndim == 1; this will raise an error in "
+ "future versions of numpy";
+ if (DEPRECATE(msg) < 0) {
+ return NULL;
+ }
+ axis = 0;
+ }
+
if (axis < 0 || axis >= ndim) {
PyErr_Format(PyExc_IndexError,
"axis %d out of bounds [0, %d)", orig_axis, ndim);
diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py
index 2017ca7a3..b3f781980 100644
--- a/numpy/core/tests/test_shape_base.py
+++ b/numpy/core/tests/test_shape_base.py
@@ -1,7 +1,7 @@
import warnings
import numpy as np
-from numpy.testing import (TestCase, assert_, assert_raises, assert_equal,
- assert_array_equal, run_module_suite)
+from numpy.testing import (TestCase, assert_, assert_raises, assert_array_equal,
+ assert_equal, run_module_suite)
from numpy.core import (array, arange, atleast_1d, atleast_2d, atleast_3d,
vstack, hstack, newaxis, concatenate)
@@ -40,6 +40,7 @@ class TestAtleast1d(TestCase):
assert_(atleast_1d(3.0).shape == (1,))
assert_(atleast_1d([[2,3],[4,5]]).shape == (2,2))
+
class TestAtleast2d(TestCase):
def test_0D_array(self):
a = array(1); b = array(2);
@@ -100,6 +101,7 @@ class TestAtleast3d(TestCase):
desired = [a,b]
assert_array_equal(res,desired)
+
class TestHstack(TestCase):
def test_0D_array(self):
a = array(1); b = array(2);
@@ -119,6 +121,7 @@ class TestHstack(TestCase):
desired = array([[1,1],[2,2]])
assert_array_equal(res,desired)
+
class TestVstack(TestCase):
def test_0D_array(self):
a = array(1); b = array(2);
@@ -159,5 +162,71 @@ def test_concatenate_axis_None():
'0', '1', '2', 'x'])
assert_array_equal(r,d)
+
+def test_concatenate():
+ # Test concatenate function
+ # No arrays raise ValueError
+ assert_raises(ValueError, concatenate, ())
+ # Scalars cannot be concatenated
+ assert_raises(ValueError, concatenate, (0,))
+ assert_raises(ValueError, concatenate, (array(0),))
+ # One sequence returns unmodified (but as array)
+ r4 = list(range(4))
+ assert_array_equal(concatenate((r4,)), r4)
+ # Any sequence
+ assert_array_equal(concatenate((tuple(r4),)), r4)
+ assert_array_equal(concatenate((array(r4),)), r4)
+ # 1D default concatenation
+ r3 = list(range(3))
+ assert_array_equal(concatenate((r4, r3)), r4 + r3)
+ # Mixed sequence types
+ assert_array_equal(concatenate((tuple(r4), r3)), r4 + r3)
+ assert_array_equal(concatenate((array(r4), r3)), r4 + r3)
+ # Explicit axis specification
+ assert_array_equal(concatenate((r4, r3), 0), r4 + r3)
+ # Including negative
+ assert_array_equal(concatenate((r4, r3), -1), r4 + r3)
+ # 2D
+ a23 = array([[10, 11, 12], [13, 14, 15]])
+ a13 = array([[0, 1, 2]])
+ res = array([[10, 11, 12], [13, 14, 15], [0, 1, 2]])
+ assert_array_equal(concatenate((a23, a13)), res)
+ assert_array_equal(concatenate((a23, a13), 0), res)
+ assert_array_equal(concatenate((a23.T, a13.T), 1), res.T)
+ assert_array_equal(concatenate((a23.T, a13.T), -1), res.T)
+ # Arrays much match shape
+ assert_raises(ValueError, concatenate, (a23.T, a13.T), 0)
+ # 3D
+ res = arange(2 * 3 * 7).reshape((2, 3, 7))
+ a0 = res[..., :4]
+ a1 = res[..., 4:6]
+ a2 = res[..., 6:]
+ assert_array_equal(concatenate((a0, a1, a2), 2), res)
+ assert_array_equal(concatenate((a0, a1, a2), -1), res)
+ assert_array_equal(concatenate((a0.T, a1.T, a2.T), 0), res.T)
+
+
+def test_concatenate_sloppy0():
+ # Versions of numpy < 1.7.0 ignored axis argument value for 1D arrays. We
+ # allow this for now, but in due course we will raise an error
+ r4 = list(range(4))
+ r3 = list(range(3))
+ assert_array_equal(concatenate((r4, r3), 0), r4 + r3)
+ warnings.simplefilter('ignore', DeprecationWarning)
+ try:
+ assert_array_equal(concatenate((r4, r3), -10), r4 + r3)
+ assert_array_equal(concatenate((r4, r3), 10), r4 + r3)
+ finally:
+ warnings.filters.pop(0)
+ # Confurm DepractionWarning raised
+ warnings.simplefilter('always', DeprecationWarning)
+ warnings.simplefilter('error', DeprecationWarning)
+ try:
+ assert_raises(DeprecationWarning, concatenate, (r4, r3), 10)
+ finally:
+ warnings.filters.pop(0)
+ warnings.filters.pop(0)
+
+
if __name__ == "__main__":
run_module_suite()