diff options
Diffstat (limited to 'numpy/linalg/umath_linalg.c.src')
-rw-r--r-- | numpy/linalg/umath_linalg.c.src | 84 |
1 files changed, 57 insertions, 27 deletions
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src index fd074e983..e5b66e276 100644 --- a/numpy/linalg/umath_linalg.c.src +++ b/numpy/linalg/umath_linalg.c.src @@ -706,6 +706,22 @@ dump_@TYPE@_matrix(const char* name, ***************************************************************************** */ +static inline fortran_int +fortran_int_min(fortran_int x, fortran_int y) { + if (x < y) + return x; + else + return y; +} + +static inline fortran_int +fortran_int_max(fortran_int x, fortran_int y) { + if (x > y) + return x; + else + return y; +} + #define INIT_OUTER_LOOP_1 \ npy_intp dN = *dimensions++;\ npy_intp N_;\ @@ -1090,9 +1106,10 @@ static inline void @basetyp@ *logdet) { fortran_int info = 0; + fortran_int lda = fortran_int_max(m, 1); int i; /* note: done in place */ - LAPACK(@cblas_type@getrf)(&m, &m, (void *)src, &m, pivots, &info); + LAPACK(@cblas_type@getrf)(&m, &m, (void *)src, &lda, pivots, &info); if (info == 0) { @@ -1231,6 +1248,7 @@ typedef struct eigh_params_struct { fortran_int LIWORK; char JOBZ; char UPLO; + fortran_int LDA; } EIGH_PARAMS_t; /**begin repeat @@ -1258,6 +1276,7 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO, npy_uint8 *a, *w, *work, *iwork; size_t safe_N = N; size_t alloc_size = safe_N * (safe_N + 1) * sizeof(@typ@); + fortran_int lda = fortran_int_max(N, 1); mem_buff = malloc(alloc_size); @@ -1266,7 +1285,7 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO, a = mem_buff; w = mem_buff + safe_N * safe_N * sizeof(@typ@); LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N, - (@ftyp@*)a, &N, (@ftyp@*)w, + (@ftyp@*)a, &lda, (@ftyp@*)w, &query_work_size, &lwork, &query_iwork_size, &liwork, &info); @@ -1295,6 +1314,7 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO, params->LIWORK = liwork; params->JOBZ = JOBZ; params->UPLO = UPLO; + params->LDA = lda; return 1; @@ -1312,7 +1332,7 @@ call_@lapack_func@(EIGH_PARAMS_t *params) { fortran_int rv; LAPACK(@lapack_func@)(¶ms->JOBZ, ¶ms->UPLO, ¶ms->N, - params->A, ¶ms->N, params->W, + params->A, ¶ms->LDA, params->W, params->WORK, ¶ms->LWORK, params->IWORK, ¶ms->LIWORK, &rv); @@ -1350,6 +1370,7 @@ init_@lapack_func@(EIGH_PARAMS_t *params, npy_uint8 *a, *w, *work, *rwork, *iwork; fortran_int info; size_t safe_N = N; + fortran_int lda = fortran_int_max(N, 1); mem_buff = malloc(safe_N * safe_N * sizeof(@typ@) + safe_N * sizeof(@basetyp@)); @@ -1359,7 +1380,7 @@ init_@lapack_func@(EIGH_PARAMS_t *params, w = mem_buff + safe_N * safe_N * sizeof(@typ@); LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N, - (@ftyp@*)a, &N, (@fbasetyp@*)w, + (@ftyp@*)a, &lda, (@fbasetyp@*)w, &query_work_size, &lwork, &query_rwork_size, &lrwork, &query_iwork_size, &liwork, @@ -1391,6 +1412,7 @@ init_@lapack_func@(EIGH_PARAMS_t *params, params->LIWORK = liwork; params->JOBZ = JOBZ; params->UPLO = UPLO; + params->LDA = lda; return 1; @@ -1408,7 +1430,7 @@ call_@lapack_func@(EIGH_PARAMS_t *params) { fortran_int rv; LAPACK(@lapack_func@)(¶ms->JOBZ, ¶ms->UPLO, ¶ms->N, - params->A, ¶ms->N, params->W, + params->A, ¶ms->LDA, params->W, params->WORK, ¶ms->LWORK, params->RWORK, ¶ms->LRWORK, params->IWORK, ¶ms->LIWORK, @@ -1590,6 +1612,7 @@ init_@lapack_func@(GESV_PARAMS_t *params, fortran_int N, fortran_int NRHS) npy_uint8 *a, *b, *ipiv; size_t safe_N = N; size_t safe_NRHS = NRHS; + fortran_int ld = fortran_int_max(N, 1); mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@) + safe_N * safe_NRHS*sizeof(@ftyp@) + safe_N * sizeof(fortran_int)); @@ -1604,8 +1627,8 @@ init_@lapack_func@(GESV_PARAMS_t *params, fortran_int N, fortran_int NRHS) params->IPIV = (fortran_int*)ipiv; params->N = N; params->NRHS = NRHS; - params->LDA = N; - params->LDB = N; + params->LDA = ld; + params->LDB = ld; return 1; error: @@ -1769,6 +1792,7 @@ init_@lapack_func@(POTR_PARAMS_t *params, char UPLO, fortran_int N) npy_uint8 *mem_buff = NULL; npy_uint8 *a; size_t safe_N = N; + fortran_int lda = fortran_int_max(N, 1); mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@)); if (!mem_buff) @@ -1778,7 +1802,7 @@ init_@lapack_func@(POTR_PARAMS_t *params, char UPLO, fortran_int N) params->A = a; params->N = N; - params->LDA = N; + params->LDA = lda; params->UPLO = UPLO; return 1; @@ -1947,6 +1971,7 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n) @typ@ work_size_query; fortran_int do_size_query = -1; fortran_int rv; + fortran_int ld = fortran_int_max(n, 1); /* allocate data for known sizes (all but work) */ mem_buff = malloc(a_size + wr_size + wi_size + @@ -1964,8 +1989,8 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n) vl = w + w_size; vr = vl + vl_size; LAPACK(@lapack_func@)(&jobvl, &jobvr, &n, - (void *)a, &n, (void *)wr, (void *)wi, - (void *)vl, &n, (void *)vr, &n, + (void *)a, &ld, (void *)wr, (void *)wi, + (void *)vl, &ld, (void *)vr, &ld, &work_size_query, &do_size_query, &rv); @@ -1988,9 +2013,9 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n) params->VL = vl; params->VR = vr; params->N = n; - params->LDA = n; - params->LDVL = n; - params->LDVR = n; + params->LDA = ld; + params->LDVL = ld; + params->LDVR = ld; params->LWORK = (fortran_int)work_count; params->JOBVL = jobvl; params->JOBVR = jobvr; @@ -2142,6 +2167,7 @@ init_@lapack_func@(GEEV_PARAMS_t* params, fortran_int do_size_query = -1; fortran_int rv; size_t total_size = a_size + w_size + vl_size + vr_size + rwork_size; + fortran_int ld = fortran_int_max(n, 1); mem_buff = malloc(total_size); if (!mem_buff) @@ -2153,8 +2179,8 @@ init_@lapack_func@(GEEV_PARAMS_t* params, vr = vl + vl_size; rwork = vr + vr_size; LAPACK(@lapack_func@)(&jobvl, &jobvr, &n, - (void *)a, &n, (void *)w, - (void *)vl, &n, (void *)vr, &n, + (void *)a, &ld, (void *)w, + (void *)vl, &ld, (void *)vr, &ld, (void *)&work_size_query, &do_size_query, (void *)rwork, &rv); @@ -2181,9 +2207,9 @@ init_@lapack_func@(GEEV_PARAMS_t* params, params->WORK = work; params->W = w; params->N = n; - params->LDA = n; - params->LDVL = n; - params->LDVR = n; + params->LDA = ld; + params->LDVL = ld; + params->LDVR = ld; params->LWORK = (fortran_int)work_count; params->JOBVL = jobvl; params->JOBVR = jobvr; @@ -2421,7 +2447,7 @@ compute_urows_vtcolumns(char jobz, fortran_int m, fortran_int n, fortran_int *urows, fortran_int *vtcolumns) { - fortran_int min_m_n = m<n?m:n; + fortran_int min_m_n = fortran_int_min(m, n); switch(jobz) { case 'N': @@ -2464,7 +2490,7 @@ init_@lapack_func@(GESDD_PARAMS_t *params, size_t safe_m = m; size_t safe_n = n; size_t a_size = safe_m * safe_n * sizeof(@ftyp@); - fortran_int min_m_n = m<n?m:n; + fortran_int min_m_n = fortran_int_min(m, n); size_t safe_min_m_n = min_m_n; size_t s_size = safe_min_m_n * sizeof(@ftyp@); fortran_int u_row_count, vt_column_count; @@ -2473,6 +2499,7 @@ init_@lapack_func@(GESDD_PARAMS_t *params, fortran_int work_count; size_t work_size; size_t iwork_size = 8 * safe_min_m_n * sizeof(fortran_int); + fortran_int ld = fortran_int_max(m, 1); if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count)) goto error; @@ -2502,13 +2529,15 @@ init_@lapack_func@(GESDD_PARAMS_t *params, fortran_int do_query = -1; fortran_int rv; LAPACK(@lapack_func@)(&jobz, &m, &n, - (void*)a, &m, (void*)s, (void*)u, &m, + (void*)a, &ld, (void*)s, (void*)u, &ld, (void*)vt, &vt_column_count, &work_size_query, &do_query, (void*)iwork, &rv); if (0!=rv) goto error; work_count = (fortran_int)work_size_query; + /* Fix a bug in lapack 3.0.0 */ + if(work_count == 0) work_count = 1; work_size = (size_t)work_count * sizeof(@ftyp@); } @@ -2529,8 +2558,8 @@ init_@lapack_func@(GESDD_PARAMS_t *params, params->IWORK = iwork; params->M = m; params->N = n; - params->LDA = m; - params->LDU = m; + params->LDA = ld; + params->LDU = ld; params->LDVT = vt_column_count; params->LWORK = work_count; params->JOBZ = jobz; @@ -2583,8 +2612,9 @@ init_@lapack_func@(GESDD_PARAMS_t *params, fortran_int u_row_count, vt_column_count, work_count; size_t safe_m = m; size_t safe_n = n; - fortran_int min_m_n = m<n?m:n; + fortran_int min_m_n = fortran_int_min(m, n); size_t safe_min_m_n = min_m_n; + fortran_int ld = fortran_int_max(m, 1); if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count)) goto error; @@ -2626,7 +2656,7 @@ init_@lapack_func@(GESDD_PARAMS_t *params, fortran_int do_query = -1; fortran_int rv; LAPACK(@lapack_func@)(&jobz, &m, &n, - (void*)a, &m, (void*)s, (void*)u, &m, + (void*)a, &ld, (void*)s, (void*)u, &ld, (void*)vt, &vt_column_count, &work_size_query, &do_query, (void*)rwork, @@ -2654,8 +2684,8 @@ init_@lapack_func@(GESDD_PARAMS_t *params, params->IWORK = iwork; params->M = m; params->N = n; - params->LDA = m; - params->LDU = m; + params->LDA = ld; + params->LDU = ld; params->LDVT = vt_column_count; params->LWORK = work_count; params->JOBZ = jobz; |