summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2009-07-12 22:36:33 +0000
committerPauli Virtanen <pav@iki.fi>2009-07-12 22:36:33 +0000
commit656cca10fa0ad2f504c17d9db25e156bc84ec554 (patch)
treecb5fbd67ffd1fd6b16ee07b78f9de98e35327c9c
parent07af346db0370c2709e545fe283123d4d9b2bd97 (diff)
downloadnumpy-656cca10fa0ad2f504c17d9db25e156bc84ec554.tar.gz
Fixed #1162: make matrix_power accept lists etc. as input
-rw-r--r--numpy/core/defmatrix.py3
-rw-r--r--numpy/core/tests/test_defmatrix.py10
2 files changed, 12 insertions, 1 deletions
diff --git a/numpy/core/defmatrix.py b/numpy/core/defmatrix.py
index cbb338469..354e40060 100644
--- a/numpy/core/defmatrix.py
+++ b/numpy/core/defmatrix.py
@@ -2,7 +2,7 @@ __all__ = ['matrix', 'bmat', 'mat', 'asmatrix']
import sys
import numeric as N
-from numeric import concatenate, isscalar, binary_repr, identity
+from numeric import concatenate, isscalar, binary_repr, identity, asanyarray
from numerictypes import issubdtype
# make translation table
@@ -115,6 +115,7 @@ def matrix_power(M,n):
[ 0, -1]])
"""
+ M = asanyarray(M)
if len(M.shape) != 2 or M.shape[0] != M.shape[1]:
raise ValueError("input must be a square array")
if not issubdtype(type(n),int):
diff --git a/numpy/core/tests/test_defmatrix.py b/numpy/core/tests/test_defmatrix.py
index e9f0d9a7f..40728bd29 100644
--- a/numpy/core/tests/test_defmatrix.py
+++ b/numpy/core/tests/test_defmatrix.py
@@ -1,5 +1,6 @@
from numpy.testing import *
from numpy.core import *
+from numpy.core.defmatrix import matrix_power
import numpy as np
class TestCtor(TestCase):
@@ -358,6 +359,15 @@ class TestNewScalarIndexing(TestCase):
assert_array_equal(x[:,[1,0]],x[:,::-1])
assert_array_equal(x[[2,1,0],:],x[::-1,:])
+class TestPower(TestCase):
+ def test_returntype(self):
+ a = array([[0,1],[0,0]])
+ assert type(matrix_power(a, 2)) is ndarray
+ a = mat(a)
+ assert type(matrix_power(a, 2)) is matrix
+
+ def test_list(self):
+ assert_array_equal(matrix_power([[0, 1], [0, 0]], 2), [[0, 0], [0, 0]])
if __name__ == "__main__":
run_module_suite()