diff options
author | czgdp1807 <gdp.1807@gmail.com> | 2021-06-04 15:05:30 +0530 |
---|---|---|
committer | czgdp1807 <gdp.1807@gmail.com> | 2021-06-04 15:05:30 +0530 |
commit | a3b58bae0f5bb19b294669a6540e079885529fd6 (patch) | |
tree | 2f5131c68b02c3b4af8fb31d349ac927259fad38 /numpy | |
parent | d5cbb5e0eefec9c97a901f1efa3e1a40d1ec0b04 (diff) | |
download | numpy-a3b58bae0f5bb19b294669a6540e079885529fd6.tar.gz |
ready for formal addition of tests and documentation update
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 56 | ||||
-rw-r--r-- | numpy/linalg/umath_linalg.c.src | 279 |
2 files changed, 280 insertions, 55 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 97d130894..2031014ac 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -769,12 +769,12 @@ def cholesky(a): # QR decomposition -def _qr_dispatcher(a, mode=None, old=None): +def _qr_dispatcher(a, mode=None): return (a,) @array_function_dispatch(_qr_dispatcher) -def qr(a, mode='reduced', old=True): +def qr(a, mode='reduced'): """ Compute the qr factorization of a matrix. @@ -925,65 +925,33 @@ def qr(a, mode='reduced', old=True): # handle modes that don't return q if mode == 'r': - return wrap(triu(a.T)) + return wrap(triu(a)) if mode == 'raw': - return wrap(a.T), tau + return wrap(transpose(a)), tau if mode == 'economic': if t != result_t : a = a.astype(result_t, copy=False) return wrap(a) - if old: - # generate q from a - if mode == 'complete' and m > n: - mc = m - q = empty((m, m), t) - else: - mc = mn - q = empty((m, n), t) - q[:n] = a - - if isComplexType(t): - lapack_routine = lapack_lite.zungqr - routine_name = 'zungqr' + if mode == 'complete' and m > n: + mc = m + if m <= n: + gufunc = _umath_linalg.qr_complete_m else: - lapack_routine = lapack_lite.dorgqr - routine_name = 'dorgqr' - - # determine optimal lwork - lwork = 1 - work = zeros((lwork,), t) - results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, -1, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) - - # compute q - lwork = max(1, n, int(abs(work[0]))) - work = zeros((lwork,), t) - results = lapack_routine(m, mc, mn, q, max(1, m), tau, work, lwork, 0) - if results['info'] != 0: - raise LinAlgError('%s returns %d' % (routine_name, results['info'])) - - q = _fastCopyAndTranspose(result_t, q[:mc]) - r = _fastCopyAndTranspose(result_t, a[:, :mc]) - - return wrap(q), wrap(triu(r)) - - if mode == 'reduced': + gufunc = _umath_linalg.qr_complete_n + else: + mc = mn if m <= n: gufunc = _umath_linalg.qr_reduced_m else: gufunc = _umath_linalg.qr_reduced_n - print(gufunc) signature = 'DD->D' if isComplexType(t) else 'dd->d' extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw) q = gufunc(a, tau, signature=signature, extobj=extobj) - - r = _fastCopyAndTranspose(result_t, a[:, :mn]) - return q, wrap(triu(r)) + return wrap(q), wrap(triu(a[:, :mc])) # Eigenvalues diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src index 045b70003..c58fbaa42 100644 --- a/numpy/linalg/umath_linalg.c.src +++ b/numpy/linalg/umath_linalg.c.src @@ -3559,6 +3559,7 @@ typedef struct gqr_params_struct fortran_int M; fortran_int MC; fortran_int MN; + void* A; void *Q; fortran_int LDA; void* TAU; @@ -3620,19 +3621,18 @@ call_@lapack_func@(GQR_PARAMS_t *params) */ static inline int init_@lapack_func@(GQR_PARAMS_t *params, - void* a, fortran_int m, fortran_int n) { - printf("%zu %zu\n", m, n); npy_uint8 *mem_buff = NULL; npy_uint8 *mem_buff2 = NULL; - npy_uint8 *q, *tau, *work; + npy_uint8 *a, *q, *tau, *work; fortran_int min_m_n = fortran_int_min(m, n); size_t safe_min_m_n = min_m_n; size_t safe_m = m; size_t safe_n = n; + size_t a_size = safe_m * safe_n * sizeof(@ftyp@); size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@); size_t tau_size = safe_min_m_n * sizeof(@ftyp@); @@ -3640,19 +3640,20 @@ init_@lapack_func@(GQR_PARAMS_t *params, size_t work_size; fortran_int lda = fortran_int_max(1, m); - mem_buff = malloc(q_size + tau_size); + mem_buff = malloc(q_size + tau_size + a_size); if (!mem_buff) goto error; q = mem_buff; - memcpy(q, a, q_size); tau = q + q_size; + a = tau + tau_size; params->M = m; params->MC = min_m_n; params->MN = min_m_n; + params->A = a; params->Q = q; params->TAU = tau; params->LDA = lda; @@ -3702,18 +3703,18 @@ init_@lapack_func@(GQR_PARAMS_t *params, */ static inline int init_@lapack_func@(GQR_PARAMS_t *params, - void* a, fortran_int m, fortran_int n) { npy_uint8 *mem_buff = NULL; npy_uint8 *mem_buff2 = NULL; - npy_uint8 *q, *tau, *work; + npy_uint8 *a, *q, *tau, *work; fortran_int min_m_n = fortran_int_min(m, n); size_t safe_min_m_n = min_m_n; size_t safe_m = m; size_t safe_n = n; + size_t a_size = safe_m * safe_n * sizeof(@ftyp@); size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@); size_t tau_size = safe_min_m_n * sizeof(@ftyp@); @@ -3721,19 +3722,20 @@ init_@lapack_func@(GQR_PARAMS_t *params, size_t work_size; fortran_int lda = fortran_int_max(1, m); - mem_buff = malloc(q_size + tau_size); + mem_buff = malloc(q_size + tau_size + a_size); if (!mem_buff) goto error; q = mem_buff; - memcpy(q, a, q_size); tau = q + q_size; + a = tau + tau_size; params->M = m; params->MC = min_m_n; params->MN = min_m_n; + params->A = a; params->Q = q; params->TAU = tau; params->LDA = lda; @@ -3814,14 +3816,243 @@ static void n = (fortran_int)dimensions[1]; - if (init_@lapack_func@(¶ms, (void*)args[0], m, n)) { - LINEARIZE_DATA_t tau_in, q_out; + if (init_@lapack_func@(¶ms, m, n)) { + LINEARIZE_DATA_t a_in, tau_in, q_out; + init_linearize_data(&a_in, n, m, steps[1], steps[0]); init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]); init_linearize_data(&q_out, fortran_int_min(m, n), m, steps[4], steps[3]); BEGIN_OUTER_LOOP_3 int not_ok; + linearize_@TYPE@_matrix(params.A, args[0], &a_in); + size_t safe_min_m_n = fortran_int_min(m, n); + size_t safe_m = m; + memcpy(params.Q, params.A, safe_min_m_n * safe_m * sizeof(@ftyp@)); + linearize_@TYPE@_matrix(params.TAU, args[1], &tau_in); + not_ok = call_@lapack_func@(¶ms); + if (!not_ok) { + delinearize_@TYPE@_matrix(args[1], params.TAU, &tau_in); + delinearize_@TYPE@_matrix(args[2], params.Q, &q_out); + } else { + error_occurred = 1; + nan_@TYPE@_matrix(args[1], &tau_in); + nan_@TYPE@_matrix(args[2], &q_out); + } + END_OUTER_LOOP + + release_@lapack_func@(¶ms); + } + + set_fp_invalid_or_clear(error_occurred); +} + +/**end repeat**/ + +/* -------------------------------------------------------------------------- */ + /* qr (modes - complete) */ + +/**begin repeat + #TYPE=FLOAT,DOUBLE# + #lapack_func=sorgqr,dorgqr# + #ftyp=fortran_real,fortran_doublereal# + */ +static inline int +init_@lapack_func@_complete(GQR_PARAMS_t *params, + fortran_int m, + fortran_int n) +{ + npy_uint8 *mem_buff = NULL; + npy_uint8 *mem_buff2 = NULL; + npy_uint8 *a, *q, *tau, *work; + fortran_int min_m_n = fortran_int_min(m, n); + size_t safe_min_m_n = min_m_n; + size_t safe_m = m; + size_t safe_n = n; + + size_t a_size = safe_m * safe_n * sizeof(@ftyp@); + size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@); + size_t tau_size = safe_min_m_n * sizeof(@ftyp@); + + fortran_int work_count; + size_t work_size; + fortran_int lda = fortran_int_max(1, m); + + mem_buff = malloc(q_size + tau_size + a_size); + + if (!mem_buff) + goto error; + + q = mem_buff; + tau = q + q_size; + a = tau + tau_size; + + + params->M = m; + params->MC = min_m_n; + params->MN = min_m_n; + params->A = a; + params->Q = q; + params->TAU = tau; + params->LDA = lda; + + { + /* compute optimal work size */ + /* IS THE FOLLOWING METHOD CORRECT? */ + @ftyp@ work_size_query; + + params->WORK = &work_size_query; + params->LWORK = -1; + + if (call_@lapack_func@(params) != 0) + goto error; + + work_count = (fortran_int)work_size_query; + + work_size = (size_t) work_size_query * sizeof(@ftyp@); + } + + mem_buff2 = malloc(work_size); + if (!mem_buff2) + goto error; + + work = mem_buff2; + memset(work, 0, work_size); + + params->WORK = work; + params->LWORK = work_count; + + return 1; + error: + TRACE_TXT("%s failed init\n", __FUNCTION__); + free(mem_buff); + free(mem_buff2); + memset(params, 0, sizeof(*params)); + + return 0; +} + +/**end repeat**/ + +/**begin repeat + #TYPE=CFLOAT,CDOUBLE# + #lapack_func=cungqr,zungqr# + #ftyp=fortran_complex,fortran_doublecomplex# + */ +static inline int +init_@lapack_func@_complete(GQR_PARAMS_t *params, + fortran_int m, + fortran_int n) +{ + npy_uint8 *mem_buff = NULL; + npy_uint8 *mem_buff2 = NULL; + npy_uint8 *a, *q, *tau, *work; + fortran_int min_m_n = fortran_int_min(m, n); + size_t safe_min_m_n = min_m_n; + size_t safe_m = m; + size_t safe_n = n; + + size_t a_size = safe_m * safe_n * sizeof(@ftyp@); + size_t q_size = safe_m * safe_m * sizeof(@ftyp@); + size_t tau_size = safe_min_m_n * sizeof(@ftyp@); + + fortran_int work_count; + size_t work_size; + fortran_int lda = fortran_int_max(1, m); + + mem_buff = malloc(q_size + tau_size + a_size); + + if (!mem_buff) + goto error; + + q = mem_buff; + tau = q + q_size; + a = tau + tau_size; + + + params->M = m; + params->MC = m; + params->MN = min_m_n; + params->A = a; + params->Q = q; + params->TAU = tau; + params->LDA = lda; + + { + /* compute optimal work size */ + /* IS THE FOLLOWING METHOD CORRECT? */ + @ftyp@ work_size_query; + + params->WORK = &work_size_query; + params->LWORK = -1; + + if (call_@lapack_func@(params) != 0) + goto error; + + work_count = (fortran_int)work_size_query.r; + + work_size = (size_t) work_size_query.r * sizeof(@ftyp@); + } + + mem_buff2 = malloc(work_size); + if (!mem_buff2) + goto error; + + work = mem_buff2; + memset(work, 0, work_size); + + params->WORK = work; + params->LWORK = work_count; + + return 1; + error: + TRACE_TXT("%s failed init\n", __FUNCTION__); + free(mem_buff); + free(mem_buff2); + memset(params, 0, sizeof(*params)); + + return 0; +} + +/**end repeat**/ + +/**begin repeat + #TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE# + #REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE# + #lapack_func=sorgqr,dorgqr,cungqr,zungqr# + #typ = npy_float, npy_double, npy_cfloat, npy_cdouble# + #basetyp = npy_float, npy_double, npy_float, npy_double# + #ftyp = fortran_real, fortran_doublereal, + fortran_complex, fortran_doublecomplex# + #cmplx = 0, 0, 1, 1# + */ +static void +@TYPE@_qr_complete(char **args, npy_intp const *dimensions, npy_intp const *steps, + void *NPY_UNUSED(func)) +{ + GQR_PARAMS_t params; + int error_occurred = get_fp_invalid_and_clear(); + fortran_int n, m; + + INIT_OUTER_LOOP_3 + + m = (fortran_int)dimensions[0]; + n = (fortran_int)dimensions[1]; + + + if (init_@lapack_func@_complete(¶ms, m, n)) { + LINEARIZE_DATA_t a_in, tau_in, q_out; + + init_linearize_data(&a_in, n, m, steps[1], steps[0]); + init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]); + init_linearize_data(&q_out, m, m, steps[4], steps[3]); + + BEGIN_OUTER_LOOP_3 + int not_ok; + linearize_@TYPE@_matrix(params.A, args[0], &a_in); + size_t safe_n = n; + size_t safe_m = m; + memcpy(params.Q, params.A, safe_n * safe_m * sizeof(@ftyp@)); linearize_@TYPE@_matrix(params.TAU, args[1], &tau_in); not_ok = call_@lapack_func@(¶ms); if (!not_ok) { @@ -3915,6 +4146,7 @@ GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_A); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(lstsq); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_r_raw); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_reduced); +GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_complete); GUFUNC_FUNC_ARRAY_EIG(eig); GUFUNC_FUNC_ARRAY_EIG(eigvals); @@ -4002,6 +4234,13 @@ static char qr_reduced_types[] = { NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE, }; +static char qr_complete_types[] = { + NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, + NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, + NPY_CFLOAT, NPY_CFLOAT, NPY_CFLOAT, + NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE, +}; + typedef struct gufunc_descriptor_struct { char *name; char *signature; @@ -4246,6 +4485,24 @@ GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = { 4, 2, 1, FUNC_ARRAY_NAME(qr_reduced), qr_reduced_types + }, + { + "qr_complete_m", + "(m,n),(m)->(m,m)", + "Compute Q matrix for the last two dimensions \n"\ + "and broadcast to the rest. For m <= n. \n", + 4, 2, 1, + FUNC_ARRAY_NAME(qr_complete), + qr_complete_types + }, + { + "qr_complete_n", + "(m,n),(n)->(m,m)", + "Compute Q matrix for the last two dimensions \n"\ + "and broadcast to the rest. For m > n. \n", + 4, 2, 1, + FUNC_ARRAY_NAME(qr_complete), + qr_complete_types } }; |