diff options
author | Pauli Virtanen <pav@iki.fi> | 2013-04-12 18:32:52 +0300 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2013-04-12 19:00:06 +0300 |
commit | aa8fde0f62a133319cfac8e8da208fcd8e224ef1 (patch) | |
tree | 2d9883ec144b252ee2aa7d98251c77f62beae182 | |
parent | 1b3834d7b59da2e809d320992c632697355d63b6 (diff) | |
download | numpy-aa8fde0f62a133319cfac8e8da208fcd8e224ef1.tar.gz |
ENH: linalg: use signature= for internal casting rather than astype in linalg ufuncs
-rw-r--r-- | numpy/linalg/linalg.py | 35 |
1 files changed, 23 insertions, 12 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index b0d43a531..ae0da3685 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -158,6 +158,7 @@ def _commonType(*arrays): t = double return t, result_type + # _fastCopyAndTranpose assumes the input is 2D (as all the calls in here are). _fastCT = fastCopyAndTranspose @@ -359,8 +360,9 @@ def solve(a, b): else: gufunc = _umath_linalg.solve + signature = 'DD->D' if isComplexType(t) else 'dd->d' extobj = get_linalg_error_extobj(_raise_linalgerror_singular) - r = gufunc(a.astype(t), b.astype(t), extobj=extobj) + r = gufunc(a, b, signature=signature, extobj=extobj) return wrap(r.astype(result_t)) @@ -493,8 +495,9 @@ def inv(a): _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) + signature = 'D->D' if isComplexType(t) else 'd->d' extobj = get_linalg_error_extobj(_raise_linalgerror_singular) - ainv = _umath_linalg.inv(a.astype(t), extobj=extobj) + ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj) return wrap(ainv.astype(result_t)) @@ -576,7 +579,8 @@ def cholesky(a): _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) - return wrap(gufunc(a.astype(t), extobj=extobj).astype(result_t)) + signature = 'D->D' if isComplexType(t) else 'd->d' + return wrap(gufunc(a, signature=signature, extobj=extobj).astype(result_t)) # QR decompostion @@ -866,8 +870,8 @@ def eigvals(a): extobj = get_linalg_error_extobj( _raise_linalgerror_eigenvalues_nonconvergence) - - w = _umath_linalg.eigvals(a.astype(t), extobj=extobj) + signature = 'D->D' if isComplexType(t) else 'd->D' + w = _umath_linalg.eigvals(a, signature=signature, extobj=extobj) if not isComplexType(t): if all(w.imag == 0): @@ -939,7 +943,8 @@ def eigvalsh(a, UPLO='L'): _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) - w = gufunc(a.astype(t), extobj=extobj) + signature = 'D->d' if isComplexType(t) else 'd->d' + w = gufunc(a, signature=signature, extobj=extobj) return w.astype(result_t) def _convertarray(a): @@ -1071,7 +1076,8 @@ def eig(a): extobj = get_linalg_error_extobj( _raise_linalgerror_eigenvalues_nonconvergence) - w, vt = _umath_linalg.eig(a.astype(t), extobj=extobj) + signature = 'D->DD' if isComplexType(t) else 'd->DD' + w, vt = _umath_linalg.eig(a, signature=signature, extobj=extobj) if not isComplexType(t) and all(w.imag == 0.0): w = w.real @@ -1182,7 +1188,8 @@ def eigh(a, UPLO='L'): else: gufunc = _umath_linalg.eigh_up - w, vt = gufunc(a.astype(t), extobj=extobj) + signature = 'D->dD' if isComplexType(t) else 'd->dd' + w, vt = gufunc(a, signature=signature, extobj=extobj) w = w.astype(_realType(result_t)) vt = vt.astype(result_t) return w, wrap(vt) @@ -1291,7 +1298,8 @@ def svd(a, full_matrices=1, compute_uv=1): else: gufunc = _umath_linalg.svd_n_s - u, s, vt = gufunc(a.astype(t), extobj=extobj) + signature = 'D->DdD' if isComplexType(t) else 'd->ddd' + u, s, vt = gufunc(a, signature=signature, extobj=extobj) u = u.astype(result_t) s = s.astype(_realType(result_t)) vt = vt.astype(result_t) @@ -1302,7 +1310,8 @@ def svd(a, full_matrices=1, compute_uv=1): else: gufunc = _umath_linalg.svd_n - s = gufunc(a.astype(t), extobj=extobj) + signature = 'D->d' if isComplexType(t) else 'd->d' + s = gufunc(a, signature=signature, extobj=extobj) s = s.astype(_realType(result_t)) return s @@ -1638,7 +1647,8 @@ def slogdet(a): _assertNdSquareness(a) t, result_t = _commonType(a) real_t = _realType(result_t) - sign, logdet = _umath_linalg.slogdet(a.astype(t)) + signature = 'D->Dd' if isComplexType(t) else 'd->dd' + sign, logdet = _umath_linalg.slogdet(a, signature=signature) return sign.astype(result_t), logdet.astype(real_t) def det(a): @@ -1690,7 +1700,8 @@ def det(a): _assertRankAtLeast2(a) _assertNdSquareness(a) t, result_t = _commonType(a) - return _umath_linalg.det(a.astype(t)).astype(result_t) + signature = 'D->D' if isComplexType(t) else 'd->d' + return _umath_linalg.det(a, signature=signature).astype(result_t) # Linear Least Squares |