diff options
-rw-r--r-- | doc/release/upcoming_changes/19151.improvement.rst | 4 | ||||
-rw-r--r-- | numpy/linalg/linalg.py | 10 | ||||
-rw-r--r-- | numpy/linalg/umath_linalg.c.src | 3 |
3 files changed, 12 insertions, 5 deletions
diff --git a/doc/release/upcoming_changes/19151.improvement.rst b/doc/release/upcoming_changes/19151.improvement.rst index 751cd8f75..e138b39bd 100644 --- a/doc/release/upcoming_changes/19151.improvement.rst +++ b/doc/release/upcoming_changes/19151.improvement.rst @@ -1,6 +1,6 @@ -`np.linalg.qr` accepts stacked matrices as inputs +`numpy.linalg.qr` accepts stacked matrices as inputs ------------------------------------------------- -`np.linalg.qr` is able to produce results for stacked matrices as inputs. +`numpy.linalg.qr` is able to produce results for stacked matrices as inputs. Moreover, the implementation of QR decomposition has been shifted to C from Python. diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 99d77120a..4b3ae79f8 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -99,7 +99,7 @@ def _raise_linalgerror_svd_nonconvergence(err, flag): def _raise_linalgerror_lstsq(err, flag): raise LinAlgError("SVD did not converge in Linear Least Squares") -def _raise_linalgerror_qr_r_raw(err, flag): +def _raise_linalgerror_qr(err, flag): raise LinAlgError("Illegal argument found while performing " "QR factorization") @@ -934,7 +934,7 @@ def qr(a, mode='reduced'): gufunc = _umath_linalg.qr_r_raw_n signature = 'D->D' if isComplexType(t) else 'd->d' - extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw) + extobj = get_linalg_error_extobj(_raise_linalgerror_qr) tau = gufunc(a, signature=signature, extobj=extobj) # handle modes that don't return q @@ -956,6 +956,10 @@ def qr(a, mode='reduced'): a = a.astype(result_t, copy=False) return wrap(a) + # mc is the number of columns in the resulting q + # matrix. If the mode is complete then it is + # same as number of rows, and if the mode is reduced, + # then it is the minimum of number of rows and columns. if mode == 'complete' and m > n: mc = m if m <= n: @@ -970,7 +974,7 @@ def qr(a, mode='reduced'): gufunc = _umath_linalg.qr_reduced_n signature = 'DD->D' if isComplexType(t) else 'dd->d' - extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw) + extobj = get_linalg_error_extobj(_raise_linalgerror_qr) q = gufunc(a, tau, signature=signature, extobj=extobj) r = triu(a[..., :mc, :]) diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src index 1a1a3babc..5b56c9565 100644 --- a/numpy/linalg/umath_linalg.c.src +++ b/numpy/linalg/umath_linalg.c.src @@ -4240,6 +4240,7 @@ static char lstsq_types[] = { NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE, }; +/* A, tau */ static char qr_r_raw_types[] = { NPY_FLOAT, NPY_FLOAT, NPY_DOUBLE, NPY_DOUBLE, @@ -4247,6 +4248,7 @@ static char qr_r_raw_types[] = { NPY_CDOUBLE, NPY_CDOUBLE, }; +/* A, tau, q */ static char qr_reduced_types[] = { NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, @@ -4254,6 +4256,7 @@ static char qr_reduced_types[] = { NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE, }; +/* A, tau, q */ static char qr_complete_types[] = { NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, |