diff options
author | czgdp1807 <gdp.1807@gmail.com> | 2021-06-05 09:00:45 +0530 |
---|---|---|
committer | czgdp1807 <gdp.1807@gmail.com> | 2021-06-05 10:55:37 +0530 |
commit | c33cf14e144f8f3f8c226e81bbf693cdba7b3187 (patch) | |
tree | 215027962e21e793d85993d02eb66f16711103b1 /numpy | |
parent | 885ba91db5132d3a4b1dfd5f60234bfedb08b432 (diff) | |
download | numpy-c33cf14e144f8f3f8c226e81bbf693cdba7b3187.tar.gz |
existing tests passed
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 19 | ||||
-rw-r--r-- | numpy/linalg/umath_linalg.c.src | 62 |
2 files changed, 57 insertions, 24 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 2031014ac..631f5b369 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -925,10 +925,17 @@ def qr(a, mode='reduced'): # handle modes that don't return q if mode == 'r': - return wrap(triu(a)) + r = triu(a[..., :mn, :]) + if t != result_t: + r = r.astype(result_t, copy=False) + return wrap(r) if mode == 'raw': - return wrap(transpose(a)), tau + q = transpose(a) + if t != result_t: + q = q.astype(result_t, copy=False) + tau = tau.astype(result_t, copy=False) + return wrap(q), tau if mode == 'economic': if t != result_t : @@ -951,7 +958,13 @@ def qr(a, mode='reduced'): signature = 'DD->D' if isComplexType(t) else 'dd->d' extobj = get_linalg_error_extobj(_raise_linalgerror_qr_r_raw) q = gufunc(a, tau, signature=signature, extobj=extobj) - return wrap(q), wrap(triu(a[:, :mc])) + r = triu(a[..., :mc, :]) + + if t != result_t: + q = q.astype(result_t, copy=False) + r = r.astype(result_t, copy=False) + + return wrap(q), wrap(r) # Eigenvalues diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src index a2f463d24..1f794a1e3 100644 --- a/numpy/linalg/umath_linalg.c.src +++ b/numpy/linalg/umath_linalg.c.src @@ -3372,6 +3372,7 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, { /* compute optimal work size */ /* IS THE FOLLOWING METHOD CORRECT? */ + @ftyp@ work_size_query; params->WORK = &work_size_query; @@ -3380,11 +3381,14 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, if (call_@lapack_func@(params) != 0) goto error; - work_count = (fortran_int)work_size_query; + work_count = (fortran_int) *(@ftyp@*) params->WORK; - work_size = (size_t) work_size_query * sizeof(@ftyp@); } + params->LWORK = fortran_int_max(fortran_int_max(1, n), + work_count > 0 ? work_count : -work_count); + + work_size = (size_t) params->LWORK * sizeof(@ftyp@); mem_buff2 = malloc(work_size); if (!mem_buff2) goto error; @@ -3393,7 +3397,6 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, memset(work, 0, work_size); params->WORK = work; - params->LWORK = work_count; return 1; error: @@ -3451,6 +3454,7 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, { /* compute optimal work size */ /* IS THE FOLLOWING METHOD CORRECT? */ + @ftyp@ work_size_query; params->WORK = &work_size_query; @@ -3459,11 +3463,15 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, if (call_@lapack_func@(params) != 0) goto error; - work_count = (fortran_int)work_size_query.r; + work_count = (fortran_int) ((@ftyp@*)params->WORK)->r; - work_size = (size_t) work_size_query.r * sizeof(@ftyp@); } + params->LWORK = fortran_int_max(fortran_int_max(1, n), + work_count > 0 ? work_count : -work_count); + + work_size = (size_t) params->LWORK * sizeof(@ftyp@); + mem_buff2 = malloc(work_size); if (!mem_buff2) goto error; @@ -3472,7 +3480,6 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, memset(work, 0, work_size); params->WORK = work; - params->LWORK = work_count; return 1; error: @@ -3631,7 +3638,6 @@ init_@lapack_func@(GQR_PARAMS_t *params, size_t safe_min_m_n = min_m_n; size_t safe_m = m; size_t safe_n = n; - size_t a_size = safe_m * safe_n * sizeof(@ftyp@); size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@); size_t tau_size = safe_min_m_n * sizeof(@ftyp@); @@ -3669,11 +3675,15 @@ init_@lapack_func@(GQR_PARAMS_t *params, if (call_@lapack_func@(params) != 0) goto error; - work_count = (fortran_int)work_size_query; + work_count = (fortran_int) *(@ftyp@*) params->WORK; - work_size = (size_t) work_size_query * sizeof(@ftyp@); } + params->LWORK = fortran_int_max(fortran_int_max(1, n), + work_count > 0 ? work_count : -work_count); + + work_size = (size_t) params->LWORK * sizeof(@ftyp@); + mem_buff2 = malloc(work_size); if (!mem_buff2) goto error; @@ -3682,7 +3692,6 @@ init_@lapack_func@(GQR_PARAMS_t *params, memset(work, 0, work_size); params->WORK = work; - params->LWORK = work_count; return 1; error: @@ -3751,11 +3760,15 @@ init_@lapack_func@(GQR_PARAMS_t *params, if (call_@lapack_func@(params) != 0) goto error; - work_count = (fortran_int)work_size_query.r; + work_count = (fortran_int) ((@ftyp@*)params->WORK)->r; - work_size = (size_t) work_size_query.r * sizeof(@ftyp@); } + params->LWORK = fortran_int_max(fortran_int_max(1, n), + work_count > 0 ? work_count : -work_count); + + work_size = (size_t) params->LWORK * sizeof(@ftyp@); + mem_buff2 = malloc(work_size); if (!mem_buff2) goto error; @@ -3862,6 +3875,7 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params, fortran_int m, fortran_int n) { + printf("Inside zungqr\n"); npy_uint8 *mem_buff = NULL; npy_uint8 *mem_buff2 = NULL; npy_uint8 *a, *q, *tau, *work; @@ -3871,7 +3885,7 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params, size_t safe_n = n; size_t a_size = safe_m * safe_n * sizeof(@ftyp@); - size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@); + size_t q_size = safe_m * safe_m * sizeof(@ftyp@); size_t tau_size = safe_min_m_n * sizeof(@ftyp@); fortran_int work_count; @@ -3889,7 +3903,7 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params, params->M = m; - params->MC = min_m_n; + params->MC = m; params->MN = min_m_n; params->A = a; params->Q = q; @@ -3907,11 +3921,15 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params, if (call_@lapack_func@(params) != 0) goto error; - work_count = (fortran_int)work_size_query; + work_count = (fortran_int) *(@ftyp@*) params->WORK; - work_size = (size_t) work_size_query * sizeof(@ftyp@); } + params->LWORK = fortran_int_max(fortran_int_max(1, n), + work_count > 0 ? work_count : -work_count); + + work_size = (size_t) params->LWORK * sizeof(@ftyp@); + mem_buff2 = malloc(work_size); if (!mem_buff2) goto error; @@ -3920,7 +3938,6 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params, memset(work, 0, work_size); params->WORK = work; - params->LWORK = work_count; return 1; error: @@ -3941,8 +3958,8 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params, */ static inline int init_@lapack_func@_complete(GQR_PARAMS_t *params, - fortran_int m, - fortran_int n) + fortran_int m, + fortran_int n) { npy_uint8 *mem_buff = NULL; npy_uint8 *mem_buff2 = NULL; @@ -3989,11 +4006,14 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params, if (call_@lapack_func@(params) != 0) goto error; - work_count = (fortran_int)work_size_query.r; + work_count = (fortran_int) ((@ftyp@*)params->WORK)->r; - work_size = (size_t) work_size_query.r * sizeof(@ftyp@); } + params->LWORK = fortran_int_max(fortran_int_max(1, n), + work_count > 0 ? work_count : -work_count); + + work_size = (size_t) params->LWORK * sizeof(@ftyp@); mem_buff2 = malloc(work_size); if (!mem_buff2) goto error; |