diff options
author | czgdp1807 <gdp.1807@gmail.com> | 2021-06-04 11:59:36 +0530 |
---|---|---|
committer | czgdp1807 <gdp.1807@gmail.com> | 2021-06-04 11:59:36 +0530 |
commit | d5cbb5e0eefec9c97a901f1efa3e1a40d1ec0b04 (patch) | |
tree | e4eb977c76457bec38cd0a7b351d66098c5bd371 /numpy | |
parent | c25c09b0688bcda1da148fbcac47adab697c409c (diff) | |
download | numpy-d5cbb5e0eefec9c97a901f1efa3e1a40d1ec0b04.tar.gz |
WIP
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 100 | ||||
-rw-r--r-- | numpy/linalg/umath_linalg.c.src | 365 |
2 files changed, 407 insertions, 58 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 087156bd0..97d130894 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -99,6 +99,9 @@ def _raise_linalgerror_svd_nonconvergence(err, flag): def _raise_linalgerror_lstsq(err, flag): raise LinAlgError("SVD did not converge in Linear Least Squares") +def _raise_linalgerror_qr_r_raw(err, flag): + raise LinAlgError("Illegal argument found while performing QR factorization") + def get_linalg_error_extobj(callback): extobj = list(_linalg_error_extobj) # make a copy extobj[2] = callback @@ -766,12 +769,12 @@ def cholesky(a): # QR decomposition -def _qr_dispatcher(a, mode=None): +def _qr_dispatcher(a, mode=None, old=None): return (a,) @array_function_dispatch(_qr_dispatcher) -def qr(a, mode='reduced'): +def qr(a, mode='reduced', old=True): """ Compute the qr factorization of a matrix. @@ -912,62 +915,75 @@ def qr(a, mode='reduced'): mn = min(m, n) if m <= n: - gufunc = _umath_linalg.qr_r_raw_e_m + gufunc = _umath_linalg.qr_r_raw_m else: - gufunc = _umath_linalg.qr_r_raw_e_n + gufunc = _umath_linalg.qr_r_raw_n signature = 'D->D' if isComplexType(t) else 'd->d' - extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq) + extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw) tau = gufunc(a, signature=signature, extobj=extobj) # handle modes that don't return q if mode == 'r': - r = _fastCopyAndTranspose(result_t, a[:, :mn]) - return wrap(triu(r)) + return wrap(triu(a.T)) if mode == 'raw': - return a, tau + return wrap(a.T), tau if mode == 'economic': if t != result_t : a = a.astype(result_t, copy=False) - return wrap(a.T) - - # generate q from a - if mode == 'complete' and m > n: - mc = m - q = empty((m, m), t) - else: - mc = mn - q = empty((n, m), t) - q[:n] = a - - if isComplexType(t): - lapack_routine = lapack_lite.zungqr - routine_name = 'zungqr' - else: - lapack_routine = lapack_lite.dorgqr - routine_name = 'dorgqr' + return wrap(a) - # determine optimal lwork - lwork = 1 - work 
= zeros((lwork,), t) - results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) - - # compute q - lwork = max(1, n, int(abs(work[0]))) - work = zeros((lwork,), t) - results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) - - q = _fastCopyAndTranspose(result_t, q[:mc]) - r = _fastCopyAndTranspose(result_t, a[:, :mc]) + if old: + # generate q from a + if mode == 'complete' and m > n: + mc = m + q = empty((m, m), t) + else: + mc = mn + q = empty((m, n), t) + q[:n] = a - return wrap(q), wrap(triu(r)) + if isComplexType(t): + lapack_routine = lapack_lite.zungqr + routine_name = 'zungqr' + else: + lapack_routine = lapack_lite.dorgqr + routine_name = 'dorgqr' + + # determine optimal lwork + lwork = 1 + work = zeros((lwork,), t) + results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0) + if results['info'] != 0: + raise LinAlgError('%s returns %d' % (routine_name, results['info'])) + + # compute q + lwork = max(1, n, int(abs(work[0]))) + work = zeros((lwork,), t) + results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0) + if results['info'] != 0: + raise LinAlgError('%s returns %d' % (routine_name, results['info'])) + + q = _fastCopyAndTranspose(result_t, q[:mc]) + r = _fastCopyAndTranspose(result_t, a[:, :mc]) + + return wrap(q), wrap(triu(r)) + + if mode == 'reduced': + if m <= n: + gufunc = _umath_linalg.qr_reduced_m + else: + gufunc = _umath_linalg.qr_reduced_n + + print(gufunc) + signature = 'DD->D' if isComplexType(t) else 'dd->d' + extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw) + q = gufunc(a, tau, signature=signature, extobj=extobj) + r = _fastCopyAndTranspose(result_t, a[:, :mn]) + return q, wrap(triu(r)) # Eigenvalues diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src index 
a2cc86dc4..045b70003 100644 --- a/numpy/linalg/umath_linalg.c.src +++ b/numpy/linalg/umath_linalg.c.src @@ -179,6 +179,23 @@ FNAME(zgeqrf)(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int f2c_doublecomplex tau[], f2c_doublecomplex work[], fortran_int *lwork, fortran_int *info); +extern fortran_int +FNAME(sorgqr)(fortran_int *m, fortran_int *n, fortran_int *k, float a[], fortran_int *lda, + float tau[], float work[], + fortran_int *lwork, fortran_int *info); +extern fortran_int +FNAME(dorgqr)(fortran_int *m, fortran_int *n, fortran_int *k, double a[], fortran_int *lda, + double tau[], double work[], + fortran_int *lwork, fortran_int *info); +extern fortran_int +FNAME(cungqr)(fortran_int *m, fortran_int *n, fortran_int *k, f2c_complex a[], + fortran_int *lda, f2c_complex tau[], + f2c_complex work[], fortran_int *lwork, fortran_int *info); +extern fortran_int +FNAME(zungqr)(fortran_int *m, fortran_int *n, fortran_int *k, f2c_doublecomplex a[], + fortran_int *lda, f2c_doublecomplex tau[], + f2c_doublecomplex work[], fortran_int *lwork, fortran_int *info); + extern fortran_int FNAME(sgesv)(fortran_int *n, fortran_int *nrhs, float a[], fortran_int *lda, @@ -3253,7 +3270,7 @@ static void /**end repeat**/ /* -------------------------------------------------------------------------- */ - /* qr (modes - r, raw, economic) */ + /* qr (modes - r, raw) */ typedef struct geqfr_params_struct { @@ -3278,7 +3295,7 @@ dump_geqrf_params(const char *name, "%14s: %18p\n"\ "%14s: %18d\n"\ "%14s: %18d\n"\ - "%14s: %18d\n" + "%14s: %18d\n"\ "%14s: %18d\n", name, @@ -3421,7 +3438,7 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, goto error; a = mem_buff; - tau = a + tau_size; + tau = a + a_size; memset(tau, 0, tau_size); @@ -3452,6 +3469,7 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, goto error; work = mem_buff2; + memset(work, 0, work_size); params->WORK = work; params->LWORK = work_count; @@ -3486,7 +3504,6 @@ release_@lapack_func@(GEQRF_PARAMS_t* params) 
#TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE# #REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE# #lapack_func=sgeqrf,dgeqrf,cgeqrf,zgeqrf# - #dot_func=sdot,ddot,cdotc,zdotc# #typ = npy_float, npy_double, npy_cfloat, npy_cdouble# #basetyp = npy_float, npy_double, npy_float, npy_double# #ftyp = fortran_real, fortran_doublereal, @@ -3494,7 +3511,7 @@ release_@lapack_func@(GEQRF_PARAMS_t* params) #cmplx = 0, 0, 1, 1# */ static void -@TYPE@_qr_r_raw_e(char **args, npy_intp const *dimensions, npy_intp const *steps, +@TYPE@_qr_r_raw(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func)) { GEQRF_PARAMS_t params; @@ -3507,10 +3524,9 @@ static void n = (fortran_int)dimensions[1]; if (init_@lapack_func@(&params, m, n)) { - LINEARIZE_DATA_t a_in, a_out, tau_out; + LINEARIZE_DATA_t a_in, tau_out; init_linearize_data(&a_in, n, m, steps[1], steps[0]); - init_linearize_data(&a_out, m, n, steps[0], steps[1]); init_linearize_data(&tau_out, 1, fortran_int_min(m, n), 1, steps[2]); BEGIN_OUTER_LOOP_2 @@ -3518,7 +3534,7 @@ static void linearize_@TYPE@_matrix(params.A, args[0], &a_in); not_ok = call_@lapack_func@(&params); if (!not_ok) { - delinearize_@TYPE@_matrix(args[0], params.A, &a_out); + delinearize_@TYPE@_matrix(args[0], params.A, &a_in); delinearize_@TYPE@_matrix(args[1], params.TAU, &tau_out); } else { error_occurred = 1; @@ -3535,6 +3551,297 @@ static void /**end repeat**/ +/* -------------------------------------------------------------------------- */ + /* qr (modes - reduced) */ + +typedef struct gqr_params_struct +{ + fortran_int M; + fortran_int MC; + fortran_int MN; + void *Q; + fortran_int LDA; + void* TAU; + void *WORK; + fortran_int LWORK; +} GQR_PARAMS_t; + + +static inline void +dump_gqr_params(const char *name, + GQR_PARAMS_t *params) +{ + TRACE_TXT("\n%s:\n"\ + + "%14s: %18p\n"\ + "%14s: %18p\n"\ + "%14s: %18p\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n", + + name, + + "Q", params->Q, + "TAU", params->TAU, + "WORK", 
params->WORK, + + "M", (int)params->M, + "MC", (int)params->MC, + "MN", (int)params->MN, + "LDA", (int)params->LDA, + "LWORK", (int)params->LWORK); +} + +/**begin repeat + #lapack_func=sorgqr,dorgqr,cungqr,zungqr# + */ + +static inline fortran_int +call_@lapack_func@(GQR_PARAMS_t *params) +{ + fortran_int rv; + LAPACK(@lapack_func@)(&params->M, &params->MC, &params->MN, + params->Q, &params->LDA, + params->TAU, + params->WORK, &params->LWORK, + &rv); + return rv; +} + +/**end repeat**/ + +/**begin repeat + #TYPE=FLOAT,DOUBLE# + #lapack_func=sorgqr,dorgqr# + #ftyp=fortran_real,fortran_doublereal# + */ +static inline int +init_@lapack_func@(GQR_PARAMS_t *params, + void* a, + fortran_int m, + fortran_int n) +{ + printf("%zu %zu\n", m, n); + npy_uint8 *mem_buff = NULL; + npy_uint8 *mem_buff2 = NULL; + npy_uint8 *q, *tau, *work; + fortran_int min_m_n = fortran_int_min(m, n); + size_t safe_min_m_n = min_m_n; + size_t safe_m = m; + size_t safe_n = n; + + size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@); + size_t tau_size = safe_min_m_n * sizeof(@ftyp@); + + fortran_int work_count; + size_t work_size; + fortran_int lda = fortran_int_max(1, m); + + mem_buff = malloc(q_size + tau_size); + + if (!mem_buff) + goto error; + + q = mem_buff; + memcpy(q, a, q_size); + tau = q + q_size; + + + params->M = m; + params->MC = min_m_n; + params->MN = min_m_n; + params->Q = q; + params->TAU = tau; + params->LDA = lda; + + { + /* compute optimal work size */ + /* IS THE FOLLOWING METHOD CORRECT? 
*/ + @ftyp@ work_size_query; + + params->WORK = &work_size_query; + params->LWORK = -1; + + if (call_@lapack_func@(params) != 0) + goto error; + + work_count = (fortran_int)work_size_query; + + work_size = (size_t) work_size_query * sizeof(@ftyp@); + } + + mem_buff2 = malloc(work_size); + if (!mem_buff2) + goto error; + + work = mem_buff2; + memset(work, 0, work_size); + + params->WORK = work; + params->LWORK = work_count; + + return 1; + error: + TRACE_TXT("%s failed init\n", __FUNCTION__); + free(mem_buff); + free(mem_buff2); + memset(params, 0, sizeof(*params)); + + return 0; +} + +/**end repeat**/ + +/**begin repeat + #TYPE=CFLOAT,CDOUBLE# + #lapack_func=cungqr,zungqr# + #ftyp=fortran_complex,fortran_doublecomplex# + */ +static inline int +init_@lapack_func@(GQR_PARAMS_t *params, + void* a, + fortran_int m, + fortran_int n) +{ + npy_uint8 *mem_buff = NULL; + npy_uint8 *mem_buff2 = NULL; + npy_uint8 *q, *tau, *work; + fortran_int min_m_n = fortran_int_min(m, n); + size_t safe_min_m_n = min_m_n; + size_t safe_m = m; + size_t safe_n = n; + + size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@); + size_t tau_size = safe_min_m_n * sizeof(@ftyp@); + + fortran_int work_count; + size_t work_size; + fortran_int lda = fortran_int_max(1, m); + + mem_buff = malloc(q_size + tau_size); + + if (!mem_buff) + goto error; + + q = mem_buff; + memcpy(q, a, q_size); + tau = q + q_size; + + + params->M = m; + params->MC = min_m_n; + params->MN = min_m_n; + params->Q = q; + params->TAU = tau; + params->LDA = lda; + + { + /* compute optimal work size */ + /* IS THE FOLLOWING METHOD CORRECT? 
*/ + @ftyp@ work_size_query; + + params->WORK = &work_size_query; + params->LWORK = -1; + + if (call_@lapack_func@(params) != 0) + goto error; + + work_count = (fortran_int)work_size_query.r; + + work_size = (size_t) work_size_query.r * sizeof(@ftyp@); + } + + mem_buff2 = malloc(work_size); + if (!mem_buff2) + goto error; + + work = mem_buff2; + memset(work, 0, work_size); + + params->WORK = work; + params->LWORK = work_count; + + return 1; + error: + TRACE_TXT("%s failed init\n", __FUNCTION__); + free(mem_buff); + free(mem_buff2); + memset(params, 0, sizeof(*params)); + + return 0; +} + +/**end repeat**/ + +/**begin repeat + #lapack_func=sorgqr,dorgqr,cungqr,zungqr# + */ +static inline void +release_@lapack_func@(GQR_PARAMS_t* params) +{ + /* A and WORK contain allocated blocks */ + free(params->Q); + free(params->WORK); + memset(params, 0, sizeof(*params)); +} + +/**end repeat**/ + +/**begin repeat + #TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE# + #REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE# + #lapack_func=sorgqr,dorgqr,cungqr,zungqr# + #typ = npy_float, npy_double, npy_cfloat, npy_cdouble# + #basetyp = npy_float, npy_double, npy_float, npy_double# + #ftyp = fortran_real, fortran_doublereal, + fortran_complex, fortran_doublecomplex# + #cmplx = 0, 0, 1, 1# + */ +static void +@TYPE@_qr_reduced(char **args, npy_intp const *dimensions, npy_intp const *steps, + void *NPY_UNUSED(func)) +{ + GQR_PARAMS_t params; + int error_occurred = get_fp_invalid_and_clear(); + fortran_int n, m; + + INIT_OUTER_LOOP_3 + + m = (fortran_int)dimensions[0]; + n = (fortran_int)dimensions[1]; + + + if (init_@lapack_func@(&params, (void*)args[0], m, n)) { + LINEARIZE_DATA_t tau_in, q_out; + + init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]); + init_linearize_data(&q_out, fortran_int_min(m, n), m, steps[4], steps[3]); + + BEGIN_OUTER_LOOP_3 + int not_ok; + linearize_@TYPE@_matrix(params.TAU, args[1], &tau_in); + not_ok = call_@lapack_func@(&params); + if (!not_ok) { + 
delinearize_@TYPE@_matrix(args[1], params.TAU, &tau_in); + delinearize_@TYPE@_matrix(args[2], params.Q, &q_out); + } else { + error_occurred = 1; + nan_@TYPE@_matrix(args[1], &tau_in); + nan_@TYPE@_matrix(args[2], &q_out); + } + END_OUTER_LOOP + + release_@lapack_func@(&params); + } + + set_fp_invalid_or_clear(error_occurred); +} + +/**end repeat**/ + #pragma GCC diagnostic pop /* -------------------------------------------------------------------------- */ @@ -3606,7 +3913,8 @@ GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_N); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_S); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_A); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(lstsq); -GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_r_raw_e); +GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_r_raw); +GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_reduced); GUFUNC_FUNC_ARRAY_EIG(eig); GUFUNC_FUNC_ARRAY_EIG(eigvals); @@ -3680,13 +3988,20 @@ static char lstsq_types[] = { NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE, }; -static char qr_r_raw_e_types[] = { +static char qr_r_raw_types[] = { NPY_FLOAT, NPY_FLOAT, NPY_DOUBLE, NPY_DOUBLE, NPY_CFLOAT, NPY_CFLOAT, NPY_CDOUBLE, NPY_CDOUBLE, }; +static char qr_reduced_types[] = { + NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, + NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, + NPY_CFLOAT, NPY_CFLOAT, NPY_CFLOAT, + NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE, +}; + typedef struct gufunc_descriptor_struct { char *name; char *signature; @@ -3897,22 +4212,40 @@ GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = { lstsq_types }, { - "qr_r_raw_e_m", + "qr_r_raw_m", "(m,n)->(m)", "Compute TAU vector for the last two dimensions \n"\ "and broadcast to the rest. For m <= n. \n", 4, 1, 1, - FUNC_ARRAY_NAME(qr_r_raw_e), - qr_r_raw_e_types + FUNC_ARRAY_NAME(qr_r_raw), + qr_r_raw_types }, { - "qr_r_raw_e_n", + "qr_r_raw_n", "(m,n)->(n)", "Compute TAU vector for the last two dimensions \n"\ "and broadcast to the rest. For m > n. 
\n", 4, 1, 1, - FUNC_ARRAY_NAME(qr_r_raw_e), - qr_r_raw_e_types + FUNC_ARRAY_NAME(qr_r_raw), + qr_r_raw_types + }, + { + "qr_reduced_m", + "(m,n),(m)->(m,m)", + "Compute Q matrix for the last two dimensions \n"\ + "and broadcast to the rest. For m <= n. \n", + 4, 2, 1, + FUNC_ARRAY_NAME(qr_reduced), + qr_reduced_types + }, + { + "qr_reduced_n", + "(m,n),(n)->(m,n)", + "Compute Q matrix for the last two dimensions \n"\ + "and broadcast to the rest. For m > n. \n", + 4, 2, 1, + FUNC_ARRAY_NAME(qr_reduced), + qr_reduced_types } }; |