diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-02-09 11:22:55 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-02-10 18:28:07 +0000 |
commit | fa040cada9e832a6d23f359a4029953d14acca0a (patch) | |
tree | a64395aaa2c25966399aee341ef60be7b65987cb | |
parent | 1b877254af0850d025cdc5d07b3fcaa1614dbe4b (diff) | |
download | numpy-fa040cada9e832a6d23f359a4029953d14acca0a.tar.gz |
BUG: make np.squeeze always return an array, never a scalar
Fixes #8588
-rw-r--r-- | doc/release/1.13.0-notes.rst | 5 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 3 | ||||
-rw-r--r-- | numpy/core/tests/test_numeric.py | 3 |
3 files changed, 9 insertions, 2 deletions
diff --git a/doc/release/1.13.0-notes.rst b/doc/release/1.13.0-notes.rst index c70568821..43bf3c7b8 100644 --- a/doc/release/1.13.0-notes.rst +++ b/doc/release/1.13.0-notes.rst @@ -76,6 +76,11 @@ obvious exception of any code that tries to directly call ``ndarray.__getslice__`` (e.g. through ``super(...).__getslice__``). In this case, ``.__getitem__(slice(start, end))`` will act as a replacement. +``np.squeeze`` always returns an array when passed a scalar +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Previously, this was only the case when passed a python scalar, and it did not +do array promotion when passed a numpy scalar. + C API ~~~~~ diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index f85e3b828..db7cad21b 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1579,8 +1579,7 @@ gentype_squeeze(PyObject *self, PyObject *args) if (!PyArg_ParseTuple(args, "")) { return NULL; } - Py_INCREF(self); - return self; + return PyArray_FromScalar(self, NULL); } static Py_ssize_t diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 4aa6bed33..919e49061 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -147,6 +147,9 @@ class TestNonarrayArgs(TestCase): A = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]]] assert_(np.squeeze(A).shape == (3, 3)) + assert_(isinstance(np.squeeze(1), np.ndarray)) + assert_(isinstance(np.squeeze(np.int32(1)), np.ndarray)) + def test_std(self): A = [[1, 2, 3], [4, 5, 6]] assert_almost_equal(np.std(A), 1.707825127659933) |