summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/release/upcoming_changes/19151.improvement.rst4
-rw-r--r--numpy/linalg/linalg.py10
-rw-r--r--numpy/linalg/umath_linalg.c.src3
3 files changed, 12 insertions, 5 deletions
diff --git a/doc/release/upcoming_changes/19151.improvement.rst b/doc/release/upcoming_changes/19151.improvement.rst
index 751cd8f75..e138b39bd 100644
--- a/doc/release/upcoming_changes/19151.improvement.rst
+++ b/doc/release/upcoming_changes/19151.improvement.rst
@@ -1,6 +1,6 @@
-`np.linalg.qr` accepts stacked matrices as inputs
+`numpy.linalg.qr` accepts stacked matrices as inputs
-------------------------------------------------
-`np.linalg.qr` is able to produce results for stacked matrices as inputs.
+`numpy.linalg.qr` is able to produce results for stacked matrices as inputs.
Moreover, the implementation of QR decomposition has been shifted to C
from Python.
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 99d77120a..4b3ae79f8 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -99,7 +99,7 @@ def _raise_linalgerror_svd_nonconvergence(err, flag):
def _raise_linalgerror_lstsq(err, flag):
raise LinAlgError("SVD did not converge in Linear Least Squares")
-def _raise_linalgerror_qr_r_raw(err, flag):
+def _raise_linalgerror_qr(err, flag):
raise LinAlgError("Illegal argument found while performing "
"QR factorization")
@@ -934,7 +934,7 @@ def qr(a, mode='reduced'):
gufunc = _umath_linalg.qr_r_raw_n
signature = 'D->D' if isComplexType(t) else 'd->d'
- extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw)
+ extobj = get_linalg_error_extobj(_raise_linalgerror_qr)
tau = gufunc(a, signature=signature, extobj=extobj)
# handle modes that don't return q
@@ -956,6 +956,10 @@ def qr(a, mode='reduced'):
a = a.astype(result_t, copy=False)
return wrap(a)
+ # mc is the number of columns in the resulting q
+ # matrix. If the mode is complete then it is
+ # same as number of rows, and if the mode is reduced,
+ # then it is the minimum of number of rows and columns.
if mode == 'complete' and m > n:
mc = m
if m <= n:
@@ -970,7 +974,7 @@ def qr(a, mode='reduced'):
gufunc = _umath_linalg.qr_reduced_n
signature = 'DD->D' if isComplexType(t) else 'dd->d'
- extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw)
+ extobj = get_linalg_error_extobj(_raise_linalgerror_qr)
q = gufunc(a, tau, signature=signature, extobj=extobj)
r = triu(a[..., :mc, :])
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src
index 1a1a3babc..5b56c9565 100644
--- a/numpy/linalg/umath_linalg.c.src
+++ b/numpy/linalg/umath_linalg.c.src
@@ -4240,6 +4240,7 @@ static char lstsq_types[] = {
NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE,
};
+/* A, tau */
static char qr_r_raw_types[] = {
NPY_FLOAT, NPY_FLOAT,
NPY_DOUBLE, NPY_DOUBLE,
@@ -4247,6 +4248,7 @@ static char qr_r_raw_types[] = {
NPY_CDOUBLE, NPY_CDOUBLE,
};
+/* A, tau, q */
static char qr_reduced_types[] = {
NPY_FLOAT, NPY_FLOAT, NPY_FLOAT,
NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,
@@ -4254,6 +4256,7 @@ static char qr_reduced_types[] = {
NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE,
};
+/* A, tau, q */
static char qr_complete_types[] = {
NPY_FLOAT, NPY_FLOAT, NPY_FLOAT,
NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,