diff options
author | Pauli Virtanen <pav@iki.fi> | 2009-07-12 22:36:33 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2009-07-12 22:36:33 +0000 |
commit | 656cca10fa0ad2f504c17d9db25e156bc84ec554 (patch) | |
tree | cb5fbd67ffd1fd6b16ee07b78f9de98e35327c9c | |
parent | 07af346db0370c2709e545fe283123d4d9b2bd97 (diff) | |
download | numpy-656cca10fa0ad2f504c17d9db25e156bc84ec554.tar.gz |
Fixed #1162: make matrix_power accept lists etc. as input
-rw-r--r-- | numpy/core/defmatrix.py | 3 | ||||
-rw-r--r-- | numpy/core/tests/test_defmatrix.py | 10 |
2 files changed, 12 insertions, 1 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py index cbb338469..354e40060 100644 --- a/numpy/core/defmatrix.py +++ b/numpy/core/defmatrix.py @@ -2,7 +2,7 @@ __all__ = ['matrix', 'bmat', 'mat', 'asmatrix'] import sys import numeric as N -from numeric import concatenate, isscalar, binary_repr, identity +from numeric import concatenate, isscalar, binary_repr, identity, asanyarray from numerictypes import issubdtype # make translation table @@ -115,6 +115,7 @@ def matrix_power(M,n): [ 0, -1]]) """ + M = asanyarray(M) if len(M.shape) != 2 or M.shape[0] != M.shape[1]: raise ValueError("input must be a square array") if not issubdtype(type(n),int): diff --git a/numpy/core/tests/test_defmatrix.py b/numpy/core/tests/test_defmatrix.py index e9f0d9a7f..40728bd29 100644 --- a/numpy/core/tests/test_defmatrix.py +++ b/numpy/core/tests/test_defmatrix.py @@ -1,5 +1,6 @@ from numpy.testing import * from numpy.core import * +from numpy.core.defmatrix import matrix_power import numpy as np class TestCtor(TestCase): @@ -358,6 +359,15 @@ class TestNewScalarIndexing(TestCase): assert_array_equal(x[:,[1,0]],x[:,::-1]) assert_array_equal(x[[2,1,0],:],x[::-1,:]) +class TestPower(TestCase): + def test_returntype(self): + a = array([[0,1],[0,0]]) + assert type(matrix_power(a, 2)) is ndarray + a = mat(a) + assert type(matrix_power(a, 2)) is matrix + + def test_list(self): + assert_array_equal(matrix_power([[0, 1], [0, 0]], 2), [[0, 0], [0, 0]]) if __name__ == "__main__": run_module_suite() |