summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorGarrett-R <garrettreynolds5@gmail.com>2014-12-08 20:33:48 -0800
committerGarrett-R <garrettreynolds5@gmail.com>2014-12-08 20:33:48 -0800
commit819b3a8a019469774a5343afd87ec71ec696bf80 (patch)
treec4bf6458781941e24b489b2066fbf755f8deeb2a /numpy
parentb8a5da49675009165326ec2e7aa6968cf6e15782 (diff)
downloadnumpy-819b3a8a019469774a5343afd87ec71ec696bf80.tar.gz
BUG: Closes #2015: diag returns ndarray
If x is a matrix, np.diag(x) and np.diagonal(x) now return matrices instead of arrays. Both of these cause x.diagonal() to be called. That means they return row vectors (just like x.flatten(), x.ravel(), x.cumprod(), etc.)
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/fromnumeric.py2
-rw-r--r--numpy/lib/twodim_base.py2
-rw-r--r--numpy/matrixlib/tests/test_numeric.py12
3 files changed, 13 insertions, 3 deletions
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index 84a10bf04..93ee07caa 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -1268,7 +1268,7 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
[5, 7]])
"""
- return asarray(a).diagonal(offset, axis1, axis2)
+ return asanyarray(a).diagonal(offset, axis1, axis2)
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py
index 40a140b6b..0c5065fa1 100644
--- a/numpy/lib/twodim_base.py
+++ b/numpy/lib/twodim_base.py
@@ -293,7 +293,7 @@ def diag(v, k=0):
[0, 0, 8]])
"""
- v = asarray(v)
+ v = asanyarray(v)
s = v.shape
if len(s) == 1:
n = s[0]+abs(k)
diff --git a/numpy/matrixlib/tests/test_numeric.py b/numpy/matrixlib/tests/test_numeric.py
index 3588de5e6..91dc92d2e 100644
--- a/numpy/matrixlib/tests/test_numeric.py
+++ b/numpy/matrixlib/tests/test_numeric.py
@@ -2,12 +2,22 @@ from __future__ import division, absolute_import, print_function
from numpy.testing import assert_equal, TestCase, run_module_suite
from numpy.core import ones
-from numpy import matrix
+from numpy import matrix, diagonal, diag
class TestDot(TestCase):
def test_matscalar(self):
b1 = matrix(ones((3, 3), dtype=complex))
assert_equal(b1*1.0, b1)
+
+def test_diagonal():
+ b1 = matrix([[1,2],[3,4]])
+ diag_b1 = matrix([[1, 4]])
+
+ assert_equal(b1.diagonal(), diag_b1)
+ assert_equal(diagonal(b1), diag_b1)
+ assert_equal(diag(b1), diag_b1)
+
+
if __name__ == "__main__":
run_module_suite()