summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2013-04-12 18:32:52 +0300
committerPauli Virtanen <pav@iki.fi>2013-04-12 19:00:06 +0300
commitaa8fde0f62a133319cfac8e8da208fcd8e224ef1 (patch)
tree2d9883ec144b252ee2aa7d98251c77f62beae182
parent1b3834d7b59da2e809d320992c632697355d63b6 (diff)
downloadnumpy-aa8fde0f62a133319cfac8e8da208fcd8e224ef1.tar.gz
ENH: linalg: use signature= for internal casting rather than astype in linalg ufuncs
-rw-r--r--numpy/linalg/linalg.py35
1 files changed, 23 insertions, 12 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index b0d43a531..ae0da3685 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -158,6 +158,7 @@ def _commonType(*arrays):
t = double
return t, result_type
+
# _fastCopyAndTranpose assumes the input is 2D (as all the calls in here are).
_fastCT = fastCopyAndTranspose
@@ -359,8 +360,9 @@ def solve(a, b):
else:
gufunc = _umath_linalg.solve
+ signature = 'DD->D' if isComplexType(t) else 'dd->d'
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
- r = gufunc(a.astype(t), b.astype(t), extobj=extobj)
+ r = gufunc(a, b, signature=signature, extobj=extobj)
return wrap(r.astype(result_t))
@@ -493,8 +495,9 @@ def inv(a):
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
+ signature = 'D->D' if isComplexType(t) else 'd->d'
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
- ainv = _umath_linalg.inv(a.astype(t), extobj=extobj)
+ ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
return wrap(ainv.astype(result_t))
@@ -576,7 +579,8 @@ def cholesky(a):
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
- return wrap(gufunc(a.astype(t), extobj=extobj).astype(result_t))
+ signature = 'D->D' if isComplexType(t) else 'd->d'
+ return wrap(gufunc(a, signature=signature, extobj=extobj).astype(result_t))
# QR decompostion
@@ -866,8 +870,8 @@ def eigvals(a):
extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
-
- w = _umath_linalg.eigvals(a.astype(t), extobj=extobj)
+ signature = 'D->D' if isComplexType(t) else 'd->D'
+ w = _umath_linalg.eigvals(a, signature=signature, extobj=extobj)
if not isComplexType(t):
if all(w.imag == 0):
@@ -939,7 +943,8 @@ def eigvalsh(a, UPLO='L'):
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
- w = gufunc(a.astype(t), extobj=extobj)
+ signature = 'D->d' if isComplexType(t) else 'd->d'
+ w = gufunc(a, signature=signature, extobj=extobj)
return w.astype(result_t)
def _convertarray(a):
@@ -1071,7 +1076,8 @@ def eig(a):
extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
- w, vt = _umath_linalg.eig(a.astype(t), extobj=extobj)
+ signature = 'D->DD' if isComplexType(t) else 'd->DD'
+ w, vt = _umath_linalg.eig(a, signature=signature, extobj=extobj)
if not isComplexType(t) and all(w.imag == 0.0):
w = w.real
@@ -1182,7 +1188,8 @@ def eigh(a, UPLO='L'):
else:
gufunc = _umath_linalg.eigh_up
- w, vt = gufunc(a.astype(t), extobj=extobj)
+ signature = 'D->dD' if isComplexType(t) else 'd->dd'
+ w, vt = gufunc(a, signature=signature, extobj=extobj)
w = w.astype(_realType(result_t))
vt = vt.astype(result_t)
return w, wrap(vt)
@@ -1291,7 +1298,8 @@ def svd(a, full_matrices=1, compute_uv=1):
else:
gufunc = _umath_linalg.svd_n_s
- u, s, vt = gufunc(a.astype(t), extobj=extobj)
+ signature = 'D->DdD' if isComplexType(t) else 'd->ddd'
+ u, s, vt = gufunc(a, signature=signature, extobj=extobj)
u = u.astype(result_t)
s = s.astype(_realType(result_t))
vt = vt.astype(result_t)
@@ -1302,7 +1310,8 @@ def svd(a, full_matrices=1, compute_uv=1):
else:
gufunc = _umath_linalg.svd_n
- s = gufunc(a.astype(t), extobj=extobj)
+ signature = 'D->d' if isComplexType(t) else 'd->d'
+ s = gufunc(a, signature=signature, extobj=extobj)
s = s.astype(_realType(result_t))
return s
@@ -1638,7 +1647,8 @@ def slogdet(a):
_assertNdSquareness(a)
t, result_t = _commonType(a)
real_t = _realType(result_t)
- sign, logdet = _umath_linalg.slogdet(a.astype(t))
+ signature = 'D->Dd' if isComplexType(t) else 'd->dd'
+ sign, logdet = _umath_linalg.slogdet(a, signature=signature)
return sign.astype(result_t), logdet.astype(real_t)
def det(a):
@@ -1690,7 +1700,8 @@ def det(a):
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
- return _umath_linalg.det(a.astype(t)).astype(result_t)
+ signature = 'D->D' if isComplexType(t) else 'd->d'
+ return _umath_linalg.det(a, signature=signature).astype(result_t)
# Linear Least Squares