diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-03-08 09:58:04 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-03-24 20:37:54 +0000 |
commit | 92f1d1221c7d41609d853f7ce0ca666f341ff607 (patch) | |
tree | 0243270c0f39d5c270f8007bb58d38a047ce6c3b /numpy/linalg/umath_linalg.c.src | |
parent | 343ef94022294fa08f037462f45a7477942fba56 (diff) | |
download | numpy-92f1d1221c7d41609d853f7ce0ca666f341ff607.tar.gz |
BUG: Correctly compute the LDA parameters for lapack calls
These are documented in the lapack source code as always being one or greater.
LD stands for "leading dimension", and can be thought of as a stride.
Diffstat (limited to 'numpy/linalg/umath_linalg.c.src')
-rw-r--r-- | numpy/linalg/umath_linalg.c.src | 84 |
1 files changed, 57 insertions, 27 deletions
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src index fd074e983..e5b66e276 100644 --- a/numpy/linalg/umath_linalg.c.src +++ b/numpy/linalg/umath_linalg.c.src @@ -706,6 +706,22 @@ dump_@TYPE@_matrix(const char* name, ***************************************************************************** */ +static inline fortran_int +fortran_int_min(fortran_int x, fortran_int y) { + if (x < y) + return x; + else + return y; +} + +static inline fortran_int +fortran_int_max(fortran_int x, fortran_int y) { + if (x > y) + return x; + else + return y; +} + #define INIT_OUTER_LOOP_1 \ npy_intp dN = *dimensions++;\ npy_intp N_;\ @@ -1090,9 +1106,10 @@ static inline void @basetyp@ *logdet) { fortran_int info = 0; + fortran_int lda = fortran_int_max(m, 1); int i; /* note: done in place */ - LAPACK(@cblas_type@getrf)(&m, &m, (void *)src, &m, pivots, &info); + LAPACK(@cblas_type@getrf)(&m, &m, (void *)src, &lda, pivots, &info); if (info == 0) { @@ -1231,6 +1248,7 @@ typedef struct eigh_params_struct { fortran_int LIWORK; char JOBZ; char UPLO; + fortran_int LDA; } EIGH_PARAMS_t; /**begin repeat @@ -1258,6 +1276,7 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO, npy_uint8 *a, *w, *work, *iwork; size_t safe_N = N; size_t alloc_size = safe_N * (safe_N + 1) * sizeof(@typ@); + fortran_int lda = fortran_int_max(N, 1); mem_buff = malloc(alloc_size); @@ -1266,7 +1285,7 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO, a = mem_buff; w = mem_buff + safe_N * safe_N * sizeof(@typ@); LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N, - (@ftyp@*)a, &N, (@ftyp@*)w, + (@ftyp@*)a, &lda, (@ftyp@*)w, &query_work_size, &lwork, &query_iwork_size, &liwork, &info); @@ -1295,6 +1314,7 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO, params->LIWORK = liwork; params->JOBZ = JOBZ; params->UPLO = UPLO; + params->LDA = lda; return 1; @@ -1312,7 +1332,7 @@ call_@lapack_func@(EIGH_PARAMS_t *params) { fortran_int rv; LAPACK(@lapack_func@)(¶ms->JOBZ, ¶ms->UPLO, ¶ms->N, - params->A, ¶ms->N, params->W, + params->A, ¶ms->LDA, params->W, params->WORK, ¶ms->LWORK, params->IWORK, ¶ms->LIWORK, &rv); @@ -1350,6 +1370,7 @@ init_@lapack_func@(EIGH_PARAMS_t *params, npy_uint8 *a, *w, *work, *rwork, *iwork; fortran_int info; size_t safe_N = N; + fortran_int lda = fortran_int_max(N, 1); mem_buff = malloc(safe_N * safe_N * sizeof(@typ@) + safe_N * sizeof(@basetyp@)); @@ -1359,7 +1380,7 @@ init_@lapack_func@(EIGH_PARAMS_t *params, w = mem_buff + safe_N * safe_N * sizeof(@typ@); LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N, - (@ftyp@*)a, &N, (@fbasetyp@*)w, + (@ftyp@*)a, &lda, (@fbasetyp@*)w, &query_work_size, &lwork, &query_rwork_size, &lrwork, &query_iwork_size, &liwork, @@ -1391,6 +1412,7 @@ init_@lapack_func@(EIGH_PARAMS_t *params, params->LIWORK = liwork; params->JOBZ = JOBZ; params->UPLO = UPLO; + params->LDA = lda; return 1; @@ -1408,7 +1430,7 @@ call_@lapack_func@(EIGH_PARAMS_t *params) { fortran_int rv; LAPACK(@lapack_func@)(¶ms->JOBZ, ¶ms->UPLO, ¶ms->N, - params->A, ¶ms->N, params->W, + params->A, ¶ms->LDA, params->W, params->WORK, ¶ms->LWORK, params->RWORK, ¶ms->LRWORK, params->IWORK, ¶ms->LIWORK, @@ -1590,6 +1612,7 @@ init_@lapack_func@(GESV_PARAMS_t *params, fortran_int N, fortran_int NRHS) npy_uint8 *a, *b, *ipiv; size_t safe_N = N; size_t safe_NRHS = NRHS; + fortran_int ld = fortran_int_max(N, 1); mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@) + safe_N * safe_NRHS*sizeof(@ftyp@) + safe_N * sizeof(fortran_int)); @@ -1604,8 +1627,8 @@ init_@lapack_func@(GESV_PARAMS_t *params, fortran_int N, fortran_int NRHS) params->IPIV = (fortran_int*)ipiv; params->N = N; params->NRHS = NRHS; - params->LDA = N; - params->LDB = N; + params->LDA = ld; + params->LDB = ld; return 1; error: @@ -1769,6 +1792,7 @@ init_@lapack_func@(POTR_PARAMS_t *params, char UPLO, fortran_int N) npy_uint8 *mem_buff = NULL; npy_uint8 *a; size_t safe_N = N; + fortran_int lda = fortran_int_max(N, 1); mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@)); if (!mem_buff) @@ -1778,7 +1802,7 @@ init_@lapack_func@(POTR_PARAMS_t *params, char UPLO, fortran_int N) params->A = a; params->N = N; - params->LDA = N; + params->LDA = lda; params->UPLO = UPLO; return 1; @@ -1947,6 +1971,7 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n) @typ@ work_size_query; fortran_int do_size_query = -1; fortran_int rv; + fortran_int ld = fortran_int_max(n, 1); /* allocate data for known sizes (all but work) */ mem_buff = malloc(a_size + wr_size + wi_size + @@ -1964,8 +1989,8 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n) vl = w + w_size; vr = vl + vl_size; LAPACK(@lapack_func@)(&jobvl, &jobvr, &n, - (void *)a, &n, (void *)wr, (void *)wi, - (void *)vl, &n, (void *)vr, &n, + (void *)a, &ld, (void *)wr, (void *)wi, + (void *)vl, &ld, (void *)vr, &ld, &work_size_query, &do_size_query, &rv); @@ -1988,9 +2013,9 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n) params->VL = vl; params->VR = vr; params->N = n; - params->LDA = n; - params->LDVL = n; - params->LDVR = n; + params->LDA = ld; + params->LDVL = ld; + params->LDVR = ld; params->LWORK = (fortran_int)work_count; params->JOBVL = jobvl; params->JOBVR = jobvr; @@ -2142,6 +2167,7 @@ init_@lapack_func@(GEEV_PARAMS_t* params, fortran_int do_size_query = -1; fortran_int rv; size_t total_size = a_size + w_size + vl_size + vr_size + rwork_size; + fortran_int ld = fortran_int_max(n, 1); mem_buff = malloc(total_size); if (!mem_buff) @@ -2153,8 +2179,8 @@ init_@lapack_func@(GEEV_PARAMS_t* params, vr = vl + vl_size; rwork = vr + vr_size; LAPACK(@lapack_func@)(&jobvl, &jobvr, &n, - (void *)a, &n, (void *)w, - (void *)vl, &n, (void *)vr, &n, + (void *)a, &ld, (void *)w, + (void *)vl, &ld, (void *)vr, &ld, (void *)&work_size_query, &do_size_query, (void *)rwork, &rv); @@ -2181,9 +2207,9 @@ init_@lapack_func@(GEEV_PARAMS_t* params, params->WORK = work; params->W = w; params->N = n; - params->LDA = n; - params->LDVL = n; - params->LDVR = n; + params->LDA = ld; + params->LDVL = ld; + params->LDVR = ld; params->LWORK = (fortran_int)work_count; params->JOBVL = jobvl; params->JOBVR = jobvr; @@ -2421,7 +2447,7 @@ compute_urows_vtcolumns(char jobz, fortran_int m, fortran_int n, fortran_int *urows, fortran_int *vtcolumns) { - fortran_int min_m_n = m<n?m:n; + fortran_int min_m_n = fortran_int_min(m, n); switch(jobz) { case 'N': @@ -2464,7 +2490,7 @@ init_@lapack_func@(GESDD_PARAMS_t *params, size_t safe_m = m; size_t safe_n = n; size_t a_size = safe_m * safe_n * sizeof(@ftyp@); - fortran_int min_m_n = m<n?m:n; + fortran_int min_m_n = fortran_int_min(m, n); size_t safe_min_m_n = min_m_n; size_t s_size = safe_min_m_n * sizeof(@ftyp@); fortran_int u_row_count, vt_column_count; @@ -2473,6 +2499,7 @@ init_@lapack_func@(GESDD_PARAMS_t *params, fortran_int work_count; size_t work_size; size_t iwork_size = 8 * safe_min_m_n * sizeof(fortran_int); + fortran_int ld = fortran_int_max(m, 1); if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count)) goto error; @@ -2502,13 +2529,15 @@ init_@lapack_func@(GESDD_PARAMS_t *params, fortran_int do_query = -1; fortran_int rv; LAPACK(@lapack_func@)(&jobz, &m, &n, - (void*)a, &m, (void*)s, (void*)u, &m, + (void*)a, &ld, (void*)s, (void*)u, &ld, (void*)vt, &vt_column_count, &work_size_query, &do_query, (void*)iwork, &rv); if (0!=rv) goto error; work_count = (fortran_int)work_size_query; + /* Fix a bug in lapack 3.0.0 */ + if(work_count == 0) work_count = 1; work_size = (size_t)work_count * sizeof(@ftyp@); } @@ -2529,8 +2558,8 @@ init_@lapack_func@(GESDD_PARAMS_t *params, params->IWORK = iwork; params->M = m; params->N = n; - params->LDA = m; - params->LDU = m; + params->LDA = ld; + params->LDU = ld; params->LDVT = vt_column_count; params->LWORK = work_count; params->JOBZ = jobz; @@ -2583,8 +2612,9 @@ init_@lapack_func@(GESDD_PARAMS_t *params, fortran_int u_row_count, vt_column_count, work_count; size_t safe_m = m; size_t safe_n = n; - fortran_int min_m_n = m<n?m:n; + fortran_int min_m_n = fortran_int_min(m, n); size_t safe_min_m_n = min_m_n; + fortran_int ld = fortran_int_max(m, 1); if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count)) goto error; @@ -2626,7 +2656,7 @@ init_@lapack_func@(GESDD_PARAMS_t *params, fortran_int do_query = -1; fortran_int rv; LAPACK(@lapack_func@)(&jobz, &m, &n, - (void*)a, &m, (void*)s, (void*)u, &m, + (void*)a, &ld, (void*)s, (void*)u, &ld, (void*)vt, &vt_column_count, &work_size_query, &do_query, (void*)rwork, @@ -2654,8 +2684,8 @@ init_@lapack_func@(GESDD_PARAMS_t *params, params->IWORK = iwork; params->M = m; params->N = n; - params->LDA = m; - params->LDU = m; + params->LDA = ld; + params->LDU = ld; params->LDVT = vt_column_count; params->LWORK = work_count; params->JOBZ = jobz; |