summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/linalg/linalg.py31
-rw-r--r--numpy/linalg/umath_linalg.c.src326
2 files changed, 335 insertions, 22 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 46fb2502e..087156bd0 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -904,34 +904,21 @@ def qr(a, mode='reduced'):
raise ValueError(f"Unrecognized mode '{mode}'")
a, wrap = _makearray(a)
- _assert_2d(a)
- m, n = a.shape
+ _assert_stacked_2d(a)
+ m, n = a.shape[-2:]
t, result_t = _commonType(a)
- a = _fastCopyAndTranspose(t, a)
+ a = a.astype(t, copy=True)
a = _to_native_byte_order(a)
mn = min(m, n)
- tau = zeros((mn,), t)
- if isComplexType(t):
- lapack_routine = lapack_lite.zgeqrf
- routine_name = 'zgeqrf'
+ if m <= n:
+ gufunc = _umath_linalg.qr_r_raw_e_m
else:
- lapack_routine = lapack_lite.dgeqrf
- routine_name = 'dgeqrf'
+ gufunc = _umath_linalg.qr_r_raw_e_n
- # calculate optimal size of work data 'work'
- lwork = 1
- work = zeros((lwork,), t)
- results = lapack_routine(m, n, a, max(1, m), tau, work, -1, 0)
- if results['info'] != 0:
- raise LinAlgError('%s returns %d' % (routine_name, results['info']))
-
- # do qr decomposition
- lwork = max(1, n, int(abs(work[0])))
- work = zeros((lwork,), t)
- results = lapack_routine(m, n, a, max(1, m), tau, work, lwork, 0)
- if results['info'] != 0:
- raise LinAlgError('%s returns %d' % (routine_name, results['info']))
+ signature = 'D->D' if isComplexType(t) else 'd->d'
+ extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq)
+ tau = gufunc(a, signature=signature, extobj=extobj)
# handle modes that don't return q
if mode == 'r':
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src
index 1807aadcf..a2cc86dc4 100644
--- a/numpy/linalg/umath_linalg.c.src
+++ b/numpy/linalg/umath_linalg.c.src
@@ -162,6 +162,23 @@ FNAME(zgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs,
double rwork[], fortran_int iwork[],
fortran_int *info);
+extern fortran_int
+FNAME(sgeqrf)(fortran_int *m, fortran_int *n, float a[], fortran_int *lda,
+ float tau[], float work[],
+ fortran_int *lwork, fortran_int *info);
+extern fortran_int
+FNAME(cgeqrf)(fortran_int *m, fortran_int *n, f2c_complex a[], fortran_int *lda,
+ f2c_complex tau[], f2c_complex work[],
+ fortran_int *lwork, fortran_int *info);
+extern fortran_int
+FNAME(dgeqrf)(fortran_int *m, fortran_int *n, double a[], fortran_int *lda,
+ double tau[], double work[],
+ fortran_int *lwork, fortran_int *info);
+extern fortran_int
+FNAME(zgeqrf)(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int *lda,
+ f2c_doublecomplex tau[], f2c_doublecomplex work[],
+ fortran_int *lwork, fortran_int *info);
+
extern fortran_int
FNAME(sgesv)(fortran_int *n, fortran_int *nrhs,
float a[], fortran_int *lda,
@@ -3235,6 +3252,289 @@ static void
/**end repeat**/
+/* -------------------------------------------------------------------------- */
+ /* qr (modes - r, raw, economic) */
+
+typedef struct geqfr_params_struct
+{
+ fortran_int M;
+ fortran_int N;
+ void *A;
+ fortran_int LDA;
+ void* TAU;
+ void *WORK;
+ fortran_int LWORK;
+} GEQRF_PARAMS_t;
+
+
+static inline void
+dump_geqrf_params(const char *name,
+ GEQRF_PARAMS_t *params)
+{
+ TRACE_TXT("\n%s:\n"\
+
+ "%14s: %18p\n"\
+ "%14s: %18p\n"\
+ "%14s: %18p\n"\
+ "%14s: %18d\n"\
+ "%14s: %18d\n"\
+ "%14s: %18d\n"
+ "%14s: %18d\n",
+
+ name,
+
+ "A", params->A,
+ "TAU", params->TAU,
+ "WORK", params->WORK,
+
+ "M", (int)params->M,
+ "N", (int)params->N,
+ "LDA", (int)params->LDA,
+ "LWORK", (int)params->LWORK);
+}
+
+/**begin repeat
+ #lapack_func=sgeqrf,dgeqrf,cgeqrf,zgeqrf#
+ */
+
+static inline fortran_int
+call_@lapack_func@(GEQRF_PARAMS_t *params)
+{
+ fortran_int rv;
+ LAPACK(@lapack_func@)(&params->M, &params->N,
+ params->A, &params->LDA,
+ params->TAU,
+ params->WORK, &params->LWORK,
+ &rv);
+ return rv;
+}
+
+/**end repeat**/
+
+/**begin repeat
+ #TYPE=FLOAT,DOUBLE#
+ #lapack_func=sgeqrf,dgeqrf#
+ #ftyp=fortran_real,fortran_doublereal#
+ */
+static inline int
+init_@lapack_func@(GEQRF_PARAMS_t *params,
+ fortran_int m,
+ fortran_int n)
+{
+ npy_uint8 *mem_buff = NULL;
+ npy_uint8 *mem_buff2 = NULL;
+ npy_uint8 *a, *tau, *work;
+ fortran_int min_m_n = fortran_int_min(m, n);
+ size_t safe_min_m_n = min_m_n;
+ size_t safe_m = m;
+ size_t safe_n = n;
+
+ size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
+ size_t tau_size = safe_min_m_n * sizeof(@ftyp@);
+
+ fortran_int work_count;
+ size_t work_size;
+ fortran_int lda = fortran_int_max(1, m);
+
+ mem_buff = malloc(a_size + tau_size);
+
+ if (!mem_buff)
+ goto error;
+
+ a = mem_buff;
+ tau = a + a_size;
+ memset(tau, 0, tau_size);
+
+
+ params->M = m;
+ params->N = n;
+ params->A = a;
+ params->TAU = tau;
+ params->LDA = lda;
+
+ {
+ /* compute optimal work size */
+ /* IS THE FOLLOWING METHOD CORRECT? */
+ @ftyp@ work_size_query;
+
+ params->WORK = &work_size_query;
+ params->LWORK = -1;
+
+ if (call_@lapack_func@(params) != 0)
+ goto error;
+
+ work_count = (fortran_int)work_size_query;
+
+ work_size = (size_t) work_size_query * sizeof(@ftyp@);
+ }
+
+ mem_buff2 = malloc(work_size);
+ if (!mem_buff2)
+ goto error;
+
+ work = mem_buff2;
+ memset(work, 0, work_size);
+
+ params->WORK = work;
+ params->LWORK = work_count;
+
+ return 1;
+ error:
+ TRACE_TXT("%s failed init\n", __FUNCTION__);
+ free(mem_buff);
+ free(mem_buff2);
+ memset(params, 0, sizeof(*params));
+
+ return 0;
+}
+
+/**end repeat**/
+
+/**begin repeat
+ #TYPE=CFLOAT,CDOUBLE#
+ #lapack_func=cgeqrf,zgeqrf#
+ #ftyp=fortran_complex,fortran_doublecomplex#
+ */
+static inline int
+init_@lapack_func@(GEQRF_PARAMS_t *params,
+ fortran_int m,
+ fortran_int n)
+{
+ npy_uint8 *mem_buff = NULL;
+ npy_uint8 *mem_buff2 = NULL;
+ npy_uint8 *a, *tau, *work;
+ fortran_int min_m_n = fortran_int_min(m, n);
+ size_t safe_min_m_n = min_m_n;
+ size_t safe_m = m;
+ size_t safe_n = n;
+
+ size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
+ size_t tau_size = safe_min_m_n * sizeof(@ftyp@);
+
+ fortran_int work_count;
+ size_t work_size;
+ fortran_int lda = fortran_int_max(1, m);
+
+ mem_buff = malloc(a_size + tau_size);
+
+ if (!mem_buff)
+ goto error;
+
+ a = mem_buff;
+ tau = a + tau_size;
+ memset(tau, 0, tau_size);
+
+
+ params->M = m;
+ params->N = n;
+ params->A = a;
+ params->TAU = tau;
+ params->LDA = lda;
+
+ {
+ /* compute optimal work size */
+ /* IS THE FOLLOWING METHOD CORRECT? */
+ @ftyp@ work_size_query;
+
+ params->WORK = &work_size_query;
+ params->LWORK = -1;
+
+ if (call_@lapack_func@(params) != 0)
+ goto error;
+
+ work_count = (fortran_int)work_size_query.r;
+
+ work_size = (size_t) work_size_query.r * sizeof(@ftyp@);
+ }
+
+ mem_buff2 = malloc(work_size);
+ if (!mem_buff2)
+ goto error;
+
+ work = mem_buff2;
+
+ params->WORK = work;
+ params->LWORK = work_count;
+
+ return 1;
+ error:
+ TRACE_TXT("%s failed init\n", __FUNCTION__);
+ free(mem_buff);
+ free(mem_buff2);
+ memset(params, 0, sizeof(*params));
+
+ return 0;
+}
+
+/**end repeat**/
+
+/**begin repeat
+ #lapack_func=sgeqrf,dgeqrf,cgeqrf,zgeqrf#
+ */
+static inline void
+release_@lapack_func@(GEQRF_PARAMS_t* params)
+{
+ /* A and WORK contain allocated blocks */
+ free(params->A);
+ free(params->WORK);
+ memset(params, 0, sizeof(*params));
+}
+
+/**end repeat**/
+
+/**begin repeat
+ #TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE#
+ #REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE#
+ #lapack_func=sgeqrf,dgeqrf,cgeqrf,zgeqrf#
+ #dot_func=sdot,ddot,cdotc,zdotc#
+ #typ = npy_float, npy_double, npy_cfloat, npy_cdouble#
+ #basetyp = npy_float, npy_double, npy_float, npy_double#
+ #ftyp = fortran_real, fortran_doublereal,
+ fortran_complex, fortran_doublecomplex#
+ #cmplx = 0, 0, 1, 1#
+ */
+static void
+@TYPE@_qr_r_raw_e(char **args, npy_intp const *dimensions, npy_intp const *steps,
+ void *NPY_UNUSED(func))
+{
+ GEQRF_PARAMS_t params;
+ int error_occurred = get_fp_invalid_and_clear();
+ fortran_int n, m;
+
+ INIT_OUTER_LOOP_2
+
+ m = (fortran_int)dimensions[0];
+ n = (fortran_int)dimensions[1];
+
+ if (init_@lapack_func@(&params, m, n)) {
+ LINEARIZE_DATA_t a_in, a_out, tau_out;
+
+ init_linearize_data(&a_in, n, m, steps[1], steps[0]);
+ init_linearize_data(&a_out, m, n, steps[0], steps[1]);
+ init_linearize_data(&tau_out, 1, fortran_int_min(m, n), 1, steps[2]);
+
+ BEGIN_OUTER_LOOP_2
+ int not_ok;
+ linearize_@TYPE@_matrix(params.A, args[0], &a_in);
+ not_ok = call_@lapack_func@(&params);
+ if (!not_ok) {
+ delinearize_@TYPE@_matrix(args[0], params.A, &a_out);
+ delinearize_@TYPE@_matrix(args[1], params.TAU, &tau_out);
+ } else {
+ error_occurred = 1;
+ nan_@TYPE@_matrix(args[0], &a_in);
+ nan_@TYPE@_matrix(args[1], &tau_out);
+ }
+ END_OUTER_LOOP
+
+ release_@lapack_func@(&params);
+ }
+
+ set_fp_invalid_or_clear(error_occurred);
+}
+
+/**end repeat**/
+
#pragma GCC diagnostic pop
/* -------------------------------------------------------------------------- */
@@ -3306,6 +3606,7 @@ GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_N);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_S);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_A);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(lstsq);
+GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_r_raw_e);
GUFUNC_FUNC_ARRAY_EIG(eig);
GUFUNC_FUNC_ARRAY_EIG(eigvals);
@@ -3379,6 +3680,13 @@ static char lstsq_types[] = {
NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE,
};
+static char qr_r_raw_e_types[] = {
+ NPY_FLOAT, NPY_FLOAT,
+ NPY_DOUBLE, NPY_DOUBLE,
+ NPY_CFLOAT, NPY_CFLOAT,
+ NPY_CDOUBLE, NPY_CDOUBLE,
+};
+
typedef struct gufunc_descriptor_struct {
char *name;
char *signature;
@@ -3587,6 +3895,24 @@ GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = {
4, 3, 4,
FUNC_ARRAY_NAME(lstsq),
lstsq_types
+ },
+ {
+ "qr_r_raw_e_m",
+ "(m,n)->(m)",
+ "Compute TAU vector for the last two dimensions \n"\
+ "and broadcast to the rest. For m <= n. \n",
+ 4, 1, 1,
+ FUNC_ARRAY_NAME(qr_r_raw_e),
+ qr_r_raw_e_types
+ },
+ {
+ "qr_r_raw_e_n",
+ "(m,n)->(n)",
+ "Compute TAU vector for the last two dimensions \n"\
+ "and broadcast to the rest. For m > n. \n",
+ 4, 1, 1,
+ FUNC_ARRAY_NAME(qr_r_raw_e),
+ qr_r_raw_e_types
}
};