summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorczgdp1807 <gdp.1807@gmail.com>2021-06-04 15:05:30 +0530
committerczgdp1807 <gdp.1807@gmail.com>2021-06-04 15:05:30 +0530
commita3b58bae0f5bb19b294669a6540e079885529fd6 (patch)
tree2f5131c68b02c3b4af8fb31d349ac927259fad38 /numpy
parentd5cbb5e0eefec9c97a901f1efa3e1a40d1ec0b04 (diff)
downloadnumpy-a3b58bae0f5bb19b294669a6540e079885529fd6.tar.gz
ready for formal addition of tests and documentation update
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/linalg.py56
-rw-r--r--numpy/linalg/umath_linalg.c.src279
2 files changed, 280 insertions, 55 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 97d130894..2031014ac 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -769,12 +769,12 @@ def cholesky(a):
# QR decomposition
-def _qr_dispatcher(a, mode=None, old=None):
+def _qr_dispatcher(a, mode=None):
return (a,)
@array_function_dispatch(_qr_dispatcher)
-def qr(a, mode='reduced', old=True):
+def qr(a, mode='reduced'):
"""
Compute the qr factorization of a matrix.
@@ -925,65 +925,33 @@ def qr(a, mode='reduced', old=True):
# handle modes that don't return q
if mode == 'r':
- return wrap(triu(a.T))
+ return wrap(triu(a))
if mode == 'raw':
- return wrap(a.T), tau
+ return wrap(transpose(a)), tau
if mode == 'economic':
if t != result_t :
a = a.astype(result_t, copy=False)
return wrap(a)
- if old:
- # generate q from a
- if mode == 'complete' and m > n:
- mc = m
- q = empty((m, m), t)
- else:
- mc = mn
- q = empty((m, n), t)
- q[:n] = a
-
- if isComplexType(t):
- lapack_routine = lapack_lite.zungqr
- routine_name = 'zungqr'
+ if mode == 'complete' and m > n:
+ mc = m
+ if m <= n:
+ gufunc = _umath_linalg.qr_complete_m
else:
- lapack_routine = lapack_lite.dorgqr
- routine_name = 'dorgqr'
-
- # determine optimal lwork
- lwork = 1
- work = zeros((lwork,), t)
- results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0)
- if results['info'] != 0:
- raise LinAlgError('%s returns %d' % (routine_name, results['info']))
-
- # compute q
- lwork = max(1, n, int(abs(work[0])))
- work = zeros((lwork,), t)
- results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0)
- if results['info'] != 0:
- raise LinAlgError('%s returns %d' % (routine_name, results['info']))
-
- q = _fastCopyAndTranspose(result_t, q[:mc])
- r = _fastCopyAndTranspose(result_t, a[:, :mc])
-
- return wrap(q), wrap(triu(r))
-
- if mode == 'reduced':
+ gufunc = _umath_linalg.qr_complete_n
+ else:
+ mc = mn
if m <= n:
gufunc = _umath_linalg.qr_reduced_m
else:
gufunc = _umath_linalg.qr_reduced_n
- print(gufunc)
signature = 'DD->D' if isComplexType(t) else 'dd->d'
extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw)
q = gufunc(a, tau, signature=signature, extobj=extobj)
-
- r = _fastCopyAndTranspose(result_t, a[:, :mn])
- return q, wrap(triu(r))
+ return wrap(q), wrap(triu(a[:, :mc]))
# Eigenvalues
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src
index 045b70003..c58fbaa42 100644
--- a/numpy/linalg/umath_linalg.c.src
+++ b/numpy/linalg/umath_linalg.c.src
@@ -3559,6 +3559,7 @@ typedef struct gqr_params_struct
fortran_int M;
fortran_int MC;
fortran_int MN;
+ void* A;
void *Q;
fortran_int LDA;
void* TAU;
@@ -3620,19 +3621,18 @@ call_@lapack_func@(GQR_PARAMS_t *params)
*/
static inline int
init_@lapack_func@(GQR_PARAMS_t *params,
- void* a,
fortran_int m,
fortran_int n)
{
- printf("%zu %zu\n", m, n);
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
- npy_uint8 *q, *tau, *work;
+ npy_uint8 *a, *q, *tau, *work;
fortran_int min_m_n = fortran_int_min(m, n);
size_t safe_min_m_n = min_m_n;
size_t safe_m = m;
size_t safe_n = n;
+ size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@);
size_t tau_size = safe_min_m_n * sizeof(@ftyp@);
@@ -3640,19 +3640,20 @@ init_@lapack_func@(GQR_PARAMS_t *params,
size_t work_size;
fortran_int lda = fortran_int_max(1, m);
- mem_buff = malloc(q_size + tau_size);
+ mem_buff = malloc(q_size + tau_size + a_size);
if (!mem_buff)
goto error;
q = mem_buff;
- memcpy(q, a, q_size);
tau = q + q_size;
+ a = tau + tau_size;
params->M = m;
params->MC = min_m_n;
params->MN = min_m_n;
+ params->A = a;
params->Q = q;
params->TAU = tau;
params->LDA = lda;
@@ -3702,18 +3703,18 @@ init_@lapack_func@(GQR_PARAMS_t *params,
*/
static inline int
init_@lapack_func@(GQR_PARAMS_t *params,
- void* a,
fortran_int m,
fortran_int n)
{
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
- npy_uint8 *q, *tau, *work;
+ npy_uint8 *a, *q, *tau, *work;
fortran_int min_m_n = fortran_int_min(m, n);
size_t safe_min_m_n = min_m_n;
size_t safe_m = m;
size_t safe_n = n;
+ size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@);
size_t tau_size = safe_min_m_n * sizeof(@ftyp@);
@@ -3721,19 +3722,20 @@ init_@lapack_func@(GQR_PARAMS_t *params,
size_t work_size;
fortran_int lda = fortran_int_max(1, m);
- mem_buff = malloc(q_size + tau_size);
+ mem_buff = malloc(q_size + tau_size + a_size);
if (!mem_buff)
goto error;
q = mem_buff;
- memcpy(q, a, q_size);
tau = q + q_size;
+ a = tau + tau_size;
params->M = m;
params->MC = min_m_n;
params->MN = min_m_n;
+ params->A = a;
params->Q = q;
params->TAU = tau;
params->LDA = lda;
@@ -3814,14 +3816,243 @@ static void
n = (fortran_int)dimensions[1];
- if (init_@lapack_func@(&params, (void*)args[0], m, n)) {
- LINEARIZE_DATA_t tau_in, q_out;
+ if (init_@lapack_func@(&params, m, n)) {
+ LINEARIZE_DATA_t a_in, tau_in, q_out;
+ init_linearize_data(&a_in, n, m, steps[1], steps[0]);
init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]);
init_linearize_data(&q_out, fortran_int_min(m, n), m, steps[4], steps[3]);
BEGIN_OUTER_LOOP_3
int not_ok;
+ linearize_@TYPE@_matrix(params.A, args[0], &a_in);
+ size_t safe_min_m_n = fortran_int_min(m, n);
+ size_t safe_m = m;
+ memcpy(params.Q, params.A, safe_min_m_n * safe_m * sizeof(@ftyp@));
+ linearize_@TYPE@_matrix(params.TAU, args[1], &tau_in);
+ not_ok = call_@lapack_func@(&params);
+ if (!not_ok) {
+ delinearize_@TYPE@_matrix(args[1], params.TAU, &tau_in);
+ delinearize_@TYPE@_matrix(args[2], params.Q, &q_out);
+ } else {
+ error_occurred = 1;
+ nan_@TYPE@_matrix(args[1], &tau_in);
+ nan_@TYPE@_matrix(args[2], &q_out);
+ }
+ END_OUTER_LOOP
+
+ release_@lapack_func@(&params);
+ }
+
+ set_fp_invalid_or_clear(error_occurred);
+}
+
+/**end repeat**/
+
+/* -------------------------------------------------------------------------- */
+ /* qr (modes - complete) */
+
+/**begin repeat
+ #TYPE=FLOAT,DOUBLE#
+ #lapack_func=sorgqr,dorgqr#
+ #ftyp=fortran_real,fortran_doublereal#
+ */
+static inline int
+init_@lapack_func@_complete(GQR_PARAMS_t *params,
+ fortran_int m,
+ fortran_int n)
+{
+ npy_uint8 *mem_buff = NULL;
+ npy_uint8 *mem_buff2 = NULL;
+ npy_uint8 *a, *q, *tau, *work;
+ fortran_int min_m_n = fortran_int_min(m, n);
+ size_t safe_min_m_n = min_m_n;
+ size_t safe_m = m;
+ size_t safe_n = n;
+
+ size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
+ size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@);
+ size_t tau_size = safe_min_m_n * sizeof(@ftyp@);
+
+ fortran_int work_count;
+ size_t work_size;
+ fortran_int lda = fortran_int_max(1, m);
+
+ mem_buff = malloc(q_size + tau_size + a_size);
+
+ if (!mem_buff)
+ goto error;
+
+ q = mem_buff;
+ tau = q + q_size;
+ a = tau + tau_size;
+
+
+ params->M = m;
+ params->MC = min_m_n;
+ params->MN = min_m_n;
+ params->A = a;
+ params->Q = q;
+ params->TAU = tau;
+ params->LDA = lda;
+
+ {
+ /* compute optimal work size */
+ /* IS THE FOLLOWING METHOD CORRECT? */
+ @ftyp@ work_size_query;
+
+ params->WORK = &work_size_query;
+ params->LWORK = -1;
+
+ if (call_@lapack_func@(params) != 0)
+ goto error;
+
+ work_count = (fortran_int)work_size_query;
+
+ work_size = (size_t) work_size_query * sizeof(@ftyp@);
+ }
+
+ mem_buff2 = malloc(work_size);
+ if (!mem_buff2)
+ goto error;
+
+ work = mem_buff2;
+ memset(work, 0, work_size);
+
+ params->WORK = work;
+ params->LWORK = work_count;
+
+ return 1;
+ error:
+ TRACE_TXT("%s failed init\n", __FUNCTION__);
+ free(mem_buff);
+ free(mem_buff2);
+ memset(params, 0, sizeof(*params));
+
+ return 0;
+}
+
+/**end repeat**/
+
+/**begin repeat
+ #TYPE=CFLOAT,CDOUBLE#
+ #lapack_func=cungqr,zungqr#
+ #ftyp=fortran_complex,fortran_doublecomplex#
+ */
+static inline int
+init_@lapack_func@_complete(GQR_PARAMS_t *params,
+ fortran_int m,
+ fortran_int n)
+{
+ npy_uint8 *mem_buff = NULL;
+ npy_uint8 *mem_buff2 = NULL;
+ npy_uint8 *a, *q, *tau, *work;
+ fortran_int min_m_n = fortran_int_min(m, n);
+ size_t safe_min_m_n = min_m_n;
+ size_t safe_m = m;
+ size_t safe_n = n;
+
+ size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
+ size_t q_size = safe_m * safe_m * sizeof(@ftyp@);
+ size_t tau_size = safe_min_m_n * sizeof(@ftyp@);
+
+ fortran_int work_count;
+ size_t work_size;
+ fortran_int lda = fortran_int_max(1, m);
+
+ mem_buff = malloc(q_size + tau_size + a_size);
+
+ if (!mem_buff)
+ goto error;
+
+ q = mem_buff;
+ tau = q + q_size;
+ a = tau + tau_size;
+
+
+ params->M = m;
+ params->MC = m;
+ params->MN = min_m_n;
+ params->A = a;
+ params->Q = q;
+ params->TAU = tau;
+ params->LDA = lda;
+
+ {
+ /* compute optimal work size */
+ /* IS THE FOLLOWING METHOD CORRECT? */
+ @ftyp@ work_size_query;
+
+ params->WORK = &work_size_query;
+ params->LWORK = -1;
+
+ if (call_@lapack_func@(params) != 0)
+ goto error;
+
+ work_count = (fortran_int)work_size_query.r;
+
+ work_size = (size_t) work_size_query.r * sizeof(@ftyp@);
+ }
+
+ mem_buff2 = malloc(work_size);
+ if (!mem_buff2)
+ goto error;
+
+ work = mem_buff2;
+ memset(work, 0, work_size);
+
+ params->WORK = work;
+ params->LWORK = work_count;
+
+ return 1;
+ error:
+ TRACE_TXT("%s failed init\n", __FUNCTION__);
+ free(mem_buff);
+ free(mem_buff2);
+ memset(params, 0, sizeof(*params));
+
+ return 0;
+}
+
+/**end repeat**/
+
+/**begin repeat
+ #TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE#
+ #REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE#
+ #lapack_func=sorgqr,dorgqr,cungqr,zungqr#
+ #typ = npy_float, npy_double, npy_cfloat, npy_cdouble#
+ #basetyp = npy_float, npy_double, npy_float, npy_double#
+ #ftyp = fortran_real, fortran_doublereal,
+ fortran_complex, fortran_doublecomplex#
+ #cmplx = 0, 0, 1, 1#
+ */
+static void
+@TYPE@_qr_complete(char **args, npy_intp const *dimensions, npy_intp const *steps,
+ void *NPY_UNUSED(func))
+{
+ GQR_PARAMS_t params;
+ int error_occurred = get_fp_invalid_and_clear();
+ fortran_int n, m;
+
+ INIT_OUTER_LOOP_3
+
+ m = (fortran_int)dimensions[0];
+ n = (fortran_int)dimensions[1];
+
+
+ if (init_@lapack_func@_complete(&params, m, n)) {
+ LINEARIZE_DATA_t a_in, tau_in, q_out;
+
+ init_linearize_data(&a_in, n, m, steps[1], steps[0]);
+ init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]);
+ init_linearize_data(&q_out, m, m, steps[4], steps[3]);
+
+ BEGIN_OUTER_LOOP_3
+ int not_ok;
+ linearize_@TYPE@_matrix(params.A, args[0], &a_in);
+ size_t safe_n = n;
+ size_t safe_m = m;
+ memcpy(params.Q, params.A, safe_n * safe_m * sizeof(@ftyp@));
linearize_@TYPE@_matrix(params.TAU, args[1], &tau_in);
not_ok = call_@lapack_func@(&params);
if (!not_ok) {
@@ -3915,6 +4146,7 @@ GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_A);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(lstsq);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_r_raw);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_reduced);
+GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_complete);
GUFUNC_FUNC_ARRAY_EIG(eig);
GUFUNC_FUNC_ARRAY_EIG(eigvals);
@@ -4002,6 +4234,13 @@ static char qr_reduced_types[] = {
NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE,
};
+static char qr_complete_types[] = {
+ NPY_FLOAT, NPY_FLOAT, NPY_FLOAT,
+ NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,
+ NPY_CFLOAT, NPY_CFLOAT, NPY_CFLOAT,
+ NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE,
+};
+
typedef struct gufunc_descriptor_struct {
char *name;
char *signature;
@@ -4246,6 +4485,24 @@ GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = {
4, 2, 1,
FUNC_ARRAY_NAME(qr_reduced),
qr_reduced_types
+ },
+ {
+ "qr_complete_m",
+ "(m,n),(m)->(m,m)",
+ "Compute Q matrix for the last two dimensions \n"\
+ "and broadcast to the rest. For m <= n. \n",
+ 4, 2, 1,
+ FUNC_ARRAY_NAME(qr_complete),
+ qr_complete_types
+ },
+ {
+ "qr_complete_n",
+ "(m,n),(n)->(m,m)",
+ "Compute Q matrix for the last two dimensions \n"\
+ "and broadcast to the rest. For m > n. \n",
+ 4, 2, 1,
+ FUNC_ARRAY_NAME(qr_complete),
+ qr_complete_types
}
};