diff options
author | czgdp1807 <gdp.1807@gmail.com> | 2021-06-03 15:02:26 +0530 |
---|---|---|
committer | czgdp1807 <gdp.1807@gmail.com> | 2021-06-03 15:02:26 +0530 |
commit | c25c09b0688bcda1da148fbcac47adab697c409c (patch) | |
tree | 5ba18bcd6215bf09bc2740d359e00d8c9f0418e5 /numpy | |
parent | 6790873334b143117f4e8d1f515def8c7fdeb9fb (diff) | |
download | numpy-c25c09b0688bcda1da148fbcac47adab697c409c.tar.gz |
ENH: r, raw, economic now accept stacked matrices
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 31 | ||||
-rw-r--r-- | numpy/linalg/umath_linalg.c.src | 326 |
2 files changed, 335 insertions, 22 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 46fb2502e..087156bd0 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -904,34 +904,21 @@ def qr(a, mode='reduced'): raise ValueError(f"Unrecognized mode '{mode}'") a, wrap = _makearray(a) - _assert_2d(a) - m, n = a.shape + _assert_stacked_2d(a) + m, n = a.shape[-2:] t, result_t = _commonType(a) - a = _fastCopyAndTranspose(t, a) + a = a.astype(t, copy=True) a = _to_native_byte_order(a) mn = min(m, n) - tau = zeros((mn,), t) - if isComplexType(t): - lapack_routine = lapack_lite.zgeqrf - routine_name = 'zgeqrf' + if m <= n: + gufunc = _umath_linalg.qr_r_raw_e_m else: - lapack_routine = lapack_lite.dgeqrf - routine_name = 'dgeqrf' + gufunc = _umath_linalg.qr_r_raw_e_n - # calculate optimal size of work data 'work' - lwork = 1 - work = zeros((lwork,), t) - results = lapack_routine(m, n, a, max(1, m), tau, work, -1, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) - - # do qr decomposition - lwork = max(1, n, int(abs(work[0]))) - work = zeros((lwork,), t) - results = lapack_routine(m, n, a, max(1, m), tau, work, lwork, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) + signature = 'D->D' if isComplexType(t) else 'd->d' + extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq) + tau = gufunc(a, signature=signature, extobj=extobj) # handle modes that don't return q if mode == 'r': diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src index 1807aadcf..a2cc86dc4 100644 --- a/numpy/linalg/umath_linalg.c.src +++ b/numpy/linalg/umath_linalg.c.src @@ -162,6 +162,23 @@ FNAME(zgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs, double rwork[], fortran_int iwork[], fortran_int *info); +extern fortran_int +FNAME(sgeqrf)(fortran_int *m, fortran_int *n, float a[], fortran_int *lda, + float tau[], float work[], + fortran_int *lwork, fortran_int *info); +extern 
fortran_int +FNAME(cgeqrf)(fortran_int *m, fortran_int *n, f2c_complex a[], fortran_int *lda, + f2c_complex tau[], f2c_complex work[], + fortran_int *lwork, fortran_int *info); +extern fortran_int +FNAME(dgeqrf)(fortran_int *m, fortran_int *n, double a[], fortran_int *lda, + double tau[], double work[], + fortran_int *lwork, fortran_int *info); +extern fortran_int +FNAME(zgeqrf)(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int *lda, + f2c_doublecomplex tau[], f2c_doublecomplex work[], + fortran_int *lwork, fortran_int *info); + extern fortran_int FNAME(sgesv)(fortran_int *n, fortran_int *nrhs, float a[], fortran_int *lda, @@ -3235,6 +3252,289 @@ static void /**end repeat**/ +/* -------------------------------------------------------------------------- */ + /* qr (modes - r, raw, economic) */ + +typedef struct geqfr_params_struct +{ + fortran_int M; + fortran_int N; + void *A; + fortran_int LDA; + void* TAU; + void *WORK; + fortran_int LWORK; +} GEQRF_PARAMS_t; + + +static inline void +dump_geqrf_params(const char *name, + GEQRF_PARAMS_t *params) +{ + TRACE_TXT("\n%s:\n"\ + + "%14s: %18p\n"\ + "%14s: %18p\n"\ + "%14s: %18p\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n" + "%14s: %18d\n", + + name, + + "A", params->A, + "TAU", params->TAU, + "WORK", params->WORK, + + "M", (int)params->M, + "N", (int)params->N, + "LDA", (int)params->LDA, + "LWORK", (int)params->LWORK); +} + +/**begin repeat + #lapack_func=sgeqrf,dgeqrf,cgeqrf,zgeqrf# + */ + +static inline fortran_int +call_@lapack_func@(GEQRF_PARAMS_t *params) +{ + fortran_int rv; + LAPACK(@lapack_func@)(&params->M, &params->N, + params->A, &params->LDA, + params->TAU, + params->WORK, &params->LWORK, + &rv); + return rv; +} + +/**end repeat**/ + +/**begin repeat + #TYPE=FLOAT,DOUBLE# + #lapack_func=sgeqrf,dgeqrf# + #ftyp=fortran_real,fortran_doublereal# + */ +static inline int +init_@lapack_func@(GEQRF_PARAMS_t *params, + fortran_int m, + fortran_int n) +{ + npy_uint8 *mem_buff = NULL; + npy_uint8 
*mem_buff2 = NULL; + npy_uint8 *a, *tau, *work; + fortran_int min_m_n = fortran_int_min(m, n); + size_t safe_min_m_n = min_m_n; + size_t safe_m = m; + size_t safe_n = n; + + size_t a_size = safe_m * safe_n * sizeof(@ftyp@); + size_t tau_size = safe_min_m_n * sizeof(@ftyp@); + + fortran_int work_count; + size_t work_size; + fortran_int lda = fortran_int_max(1, m); + + mem_buff = malloc(a_size + tau_size); + + if (!mem_buff) + goto error; + + a = mem_buff; + tau = a + a_size; + memset(tau, 0, tau_size); + + + params->M = m; + params->N = n; + params->A = a; + params->TAU = tau; + params->LDA = lda; + + { + /* compute optimal work size */ + /* IS THE FOLLOWING METHOD CORRECT? */ + @ftyp@ work_size_query; + + params->WORK = &work_size_query; + params->LWORK = -1; + + if (call_@lapack_func@(params) != 0) + goto error; + + work_count = (fortran_int)work_size_query; + + work_size = (size_t) work_size_query * sizeof(@ftyp@); + } + + mem_buff2 = malloc(work_size); + if (!mem_buff2) + goto error; + + work = mem_buff2; + memset(work, 0, work_size); + + params->WORK = work; + params->LWORK = work_count; + + return 1; + error: + TRACE_TXT("%s failed init\n", __FUNCTION__); + free(mem_buff); + free(mem_buff2); + memset(params, 0, sizeof(*params)); + + return 0; +} + +/**end repeat**/ + +/**begin repeat + #TYPE=CFLOAT,CDOUBLE# + #lapack_func=cgeqrf,zgeqrf# + #ftyp=fortran_complex,fortran_doublecomplex# + */ +static inline int +init_@lapack_func@(GEQRF_PARAMS_t *params, + fortran_int m, + fortran_int n) +{ + npy_uint8 *mem_buff = NULL; + npy_uint8 *mem_buff2 = NULL; + npy_uint8 *a, *tau, *work; + fortran_int min_m_n = fortran_int_min(m, n); + size_t safe_min_m_n = min_m_n; + size_t safe_m = m; + size_t safe_n = n; + + size_t a_size = safe_m * safe_n * sizeof(@ftyp@); + size_t tau_size = safe_min_m_n * sizeof(@ftyp@); + + fortran_int work_count; + size_t work_size; + fortran_int lda = fortran_int_max(1, m); + + mem_buff = malloc(a_size + tau_size); + + if (!mem_buff) + goto error; 
+ + a = mem_buff; + tau = a + tau_size; + memset(tau, 0, tau_size); + + + params->M = m; + params->N = n; + params->A = a; + params->TAU = tau; + params->LDA = lda; + + { + /* compute optimal work size */ + /* IS THE FOLLOWING METHOD CORRECT? */ + @ftyp@ work_size_query; + + params->WORK = &work_size_query; + params->LWORK = -1; + + if (call_@lapack_func@(params) != 0) + goto error; + + work_count = (fortran_int)work_size_query.r; + + work_size = (size_t) work_size_query.r * sizeof(@ftyp@); + } + + mem_buff2 = malloc(work_size); + if (!mem_buff2) + goto error; + + work = mem_buff2; + + params->WORK = work; + params->LWORK = work_count; + + return 1; + error: + TRACE_TXT("%s failed init\n", __FUNCTION__); + free(mem_buff); + free(mem_buff2); + memset(params, 0, sizeof(*params)); + + return 0; +} + +/**end repeat**/ + +/**begin repeat + #lapack_func=sgeqrf,dgeqrf,cgeqrf,zgeqrf# + */ +static inline void +release_@lapack_func@(GEQRF_PARAMS_t* params) +{ + /* A and WORK contain allocated blocks */ + free(params->A); + free(params->WORK); + memset(params, 0, sizeof(*params)); +} + +/**end repeat**/ + +/**begin repeat + #TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE# + #REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE# + #lapack_func=sgeqrf,dgeqrf,cgeqrf,zgeqrf# + #dot_func=sdot,ddot,cdotc,zdotc# + #typ = npy_float, npy_double, npy_cfloat, npy_cdouble# + #basetyp = npy_float, npy_double, npy_float, npy_double# + #ftyp = fortran_real, fortran_doublereal, + fortran_complex, fortran_doublecomplex# + #cmplx = 0, 0, 1, 1# + */ +static void +@TYPE@_qr_r_raw_e(char **args, npy_intp const *dimensions, npy_intp const *steps, + void *NPY_UNUSED(func)) +{ + GEQRF_PARAMS_t params; + int error_occurred = get_fp_invalid_and_clear(); + fortran_int n, m; + + INIT_OUTER_LOOP_2 + + m = (fortran_int)dimensions[0]; + n = (fortran_int)dimensions[1]; + + if (init_@lapack_func@(&params, m, n)) { + LINEARIZE_DATA_t a_in, a_out, tau_out; + + init_linearize_data(&a_in, n, m, steps[1], steps[0]); + init_linearize_data(&a_out, m, 
n, steps[0], steps[1]); + init_linearize_data(&tau_out, 1, fortran_int_min(m, n), 1, steps[2]); + + BEGIN_OUTER_LOOP_2 + int not_ok; + linearize_@TYPE@_matrix(params.A, args[0], &a_in); + not_ok = call_@lapack_func@(&params); + if (!not_ok) { + delinearize_@TYPE@_matrix(args[0], params.A, &a_out); + delinearize_@TYPE@_matrix(args[1], params.TAU, &tau_out); + } else { + error_occurred = 1; + nan_@TYPE@_matrix(args[0], &a_in); + nan_@TYPE@_matrix(args[1], &tau_out); + } + END_OUTER_LOOP + + release_@lapack_func@(&params); + } + + set_fp_invalid_or_clear(error_occurred); +} + +/**end repeat**/ + #pragma GCC diagnostic pop /* -------------------------------------------------------------------------- */ @@ -3306,6 +3606,7 @@ GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_N); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_S); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_A); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(lstsq); +GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_r_raw_e); GUFUNC_FUNC_ARRAY_EIG(eig); GUFUNC_FUNC_ARRAY_EIG(eigvals); @@ -3379,6 +3680,13 @@ static char lstsq_types[] = { NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE, }; +static char qr_r_raw_e_types[] = { + NPY_FLOAT, NPY_FLOAT, + NPY_DOUBLE, NPY_DOUBLE, + NPY_CFLOAT, NPY_CFLOAT, + NPY_CDOUBLE, NPY_CDOUBLE, +}; + typedef struct gufunc_descriptor_struct { char *name; char *signature; @@ -3587,6 +3895,24 @@ GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = { 4, 3, 4, FUNC_ARRAY_NAME(lstsq), lstsq_types + }, + { + "qr_r_raw_e_m", + "(m,n)->(m)", + "Compute TAU vector for the last two dimensions \n"\ + "and broadcast to the rest. For m <= n. \n", + 4, 1, 1, + FUNC_ARRAY_NAME(qr_r_raw_e), + qr_r_raw_e_types + }, + { + "qr_r_raw_e_n", + "(m,n)->(n)", + "Compute TAU vector for the last two dimensions \n"\ + "and broadcast to the rest. For m > n. \n", + 4, 1, 1, + FUNC_ARRAY_NAME(qr_r_raw_e), + qr_r_raw_e_types } }; |