summaryrefslogtreecommitdiff
path: root/numpy/linalg/umath_linalg.c.src
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/umath_linalg.c.src')
-rw-r--r--numpy/linalg/umath_linalg.c.src84
1 files changed, 57 insertions, 27 deletions
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src
index fd074e983..e5b66e276 100644
--- a/numpy/linalg/umath_linalg.c.src
+++ b/numpy/linalg/umath_linalg.c.src
@@ -706,6 +706,22 @@ dump_@TYPE@_matrix(const char* name,
*****************************************************************************
*/
+static inline fortran_int
+fortran_int_min(fortran_int x, fortran_int y) {
+ if (x < y)
+ return x;
+ else
+ return y;
+}
+
+static inline fortran_int
+fortran_int_max(fortran_int x, fortran_int y) {
+ if (x > y)
+ return x;
+ else
+ return y;
+}
+
#define INIT_OUTER_LOOP_1 \
npy_intp dN = *dimensions++;\
npy_intp N_;\
@@ -1090,9 +1106,10 @@ static inline void
@basetyp@ *logdet)
{
fortran_int info = 0;
+ fortran_int lda = fortran_int_max(m, 1);
int i;
/* note: done in place */
- LAPACK(@cblas_type@getrf)(&m, &m, (void *)src, &m, pivots, &info);
+ LAPACK(@cblas_type@getrf)(&m, &m, (void *)src, &lda, pivots, &info);
if (info == 0)
{
@@ -1231,6 +1248,7 @@ typedef struct eigh_params_struct {
fortran_int LIWORK;
char JOBZ;
char UPLO;
+ fortran_int LDA;
} EIGH_PARAMS_t;
/**begin repeat
@@ -1258,6 +1276,7 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO,
npy_uint8 *a, *w, *work, *iwork;
size_t safe_N = N;
size_t alloc_size = safe_N * (safe_N + 1) * sizeof(@typ@);
+ fortran_int lda = fortran_int_max(N, 1);
mem_buff = malloc(alloc_size);
@@ -1266,7 +1285,7 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO,
a = mem_buff;
w = mem_buff + safe_N * safe_N * sizeof(@typ@);
LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N,
- (@ftyp@*)a, &N, (@ftyp@*)w,
+ (@ftyp@*)a, &lda, (@ftyp@*)w,
&query_work_size, &lwork,
&query_iwork_size, &liwork,
&info);
@@ -1295,6 +1314,7 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO,
params->LIWORK = liwork;
params->JOBZ = JOBZ;
params->UPLO = UPLO;
+ params->LDA = lda;
return 1;
@@ -1312,7 +1332,7 @@ call_@lapack_func@(EIGH_PARAMS_t *params)
{
fortran_int rv;
LAPACK(@lapack_func@)(&params->JOBZ, &params->UPLO, &params->N,
- params->A, &params->N, params->W,
+ params->A, &params->LDA, params->W,
params->WORK, &params->LWORK,
params->IWORK, &params->LIWORK,
&rv);
@@ -1350,6 +1370,7 @@ init_@lapack_func@(EIGH_PARAMS_t *params,
npy_uint8 *a, *w, *work, *rwork, *iwork;
fortran_int info;
size_t safe_N = N;
+ fortran_int lda = fortran_int_max(N, 1);
mem_buff = malloc(safe_N * safe_N * sizeof(@typ@) +
safe_N * sizeof(@basetyp@));
@@ -1359,7 +1380,7 @@ init_@lapack_func@(EIGH_PARAMS_t *params,
w = mem_buff + safe_N * safe_N * sizeof(@typ@);
LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N,
- (@ftyp@*)a, &N, (@fbasetyp@*)w,
+ (@ftyp@*)a, &lda, (@fbasetyp@*)w,
&query_work_size, &lwork,
&query_rwork_size, &lrwork,
&query_iwork_size, &liwork,
@@ -1391,6 +1412,7 @@ init_@lapack_func@(EIGH_PARAMS_t *params,
params->LIWORK = liwork;
params->JOBZ = JOBZ;
params->UPLO = UPLO;
+ params->LDA = lda;
return 1;
@@ -1408,7 +1430,7 @@ call_@lapack_func@(EIGH_PARAMS_t *params)
{
fortran_int rv;
LAPACK(@lapack_func@)(&params->JOBZ, &params->UPLO, &params->N,
- params->A, &params->N, params->W,
+ params->A, &params->LDA, params->W,
params->WORK, &params->LWORK,
params->RWORK, &params->LRWORK,
params->IWORK, &params->LIWORK,
@@ -1590,6 +1612,7 @@ init_@lapack_func@(GESV_PARAMS_t *params, fortran_int N, fortran_int NRHS)
npy_uint8 *a, *b, *ipiv;
size_t safe_N = N;
size_t safe_NRHS = NRHS;
+ fortran_int ld = fortran_int_max(N, 1);
mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@) +
safe_N * safe_NRHS*sizeof(@ftyp@) +
safe_N * sizeof(fortran_int));
@@ -1604,8 +1627,8 @@ init_@lapack_func@(GESV_PARAMS_t *params, fortran_int N, fortran_int NRHS)
params->IPIV = (fortran_int*)ipiv;
params->N = N;
params->NRHS = NRHS;
- params->LDA = N;
- params->LDB = N;
+ params->LDA = ld;
+ params->LDB = ld;
return 1;
error:
@@ -1769,6 +1792,7 @@ init_@lapack_func@(POTR_PARAMS_t *params, char UPLO, fortran_int N)
npy_uint8 *mem_buff = NULL;
npy_uint8 *a;
size_t safe_N = N;
+ fortran_int lda = fortran_int_max(N, 1);
mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@));
if (!mem_buff)
@@ -1778,7 +1802,7 @@ init_@lapack_func@(POTR_PARAMS_t *params, char UPLO, fortran_int N)
params->A = a;
params->N = N;
- params->LDA = N;
+ params->LDA = lda;
params->UPLO = UPLO;
return 1;
@@ -1947,6 +1971,7 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n)
@typ@ work_size_query;
fortran_int do_size_query = -1;
fortran_int rv;
+ fortran_int ld = fortran_int_max(n, 1);
/* allocate data for known sizes (all but work) */
mem_buff = malloc(a_size + wr_size + wi_size +
@@ -1964,8 +1989,8 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n)
vl = w + w_size;
vr = vl + vl_size;
LAPACK(@lapack_func@)(&jobvl, &jobvr, &n,
- (void *)a, &n, (void *)wr, (void *)wi,
- (void *)vl, &n, (void *)vr, &n,
+ (void *)a, &ld, (void *)wr, (void *)wi,
+ (void *)vl, &ld, (void *)vr, &ld,
&work_size_query, &do_size_query,
&rv);
@@ -1988,9 +2013,9 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n)
params->VL = vl;
params->VR = vr;
params->N = n;
- params->LDA = n;
- params->LDVL = n;
- params->LDVR = n;
+ params->LDA = ld;
+ params->LDVL = ld;
+ params->LDVR = ld;
params->LWORK = (fortran_int)work_count;
params->JOBVL = jobvl;
params->JOBVR = jobvr;
@@ -2142,6 +2167,7 @@ init_@lapack_func@(GEEV_PARAMS_t* params,
fortran_int do_size_query = -1;
fortran_int rv;
size_t total_size = a_size + w_size + vl_size + vr_size + rwork_size;
+ fortran_int ld = fortran_int_max(n, 1);
mem_buff = malloc(total_size);
if (!mem_buff)
@@ -2153,8 +2179,8 @@ init_@lapack_func@(GEEV_PARAMS_t* params,
vr = vl + vl_size;
rwork = vr + vr_size;
LAPACK(@lapack_func@)(&jobvl, &jobvr, &n,
- (void *)a, &n, (void *)w,
- (void *)vl, &n, (void *)vr, &n,
+ (void *)a, &ld, (void *)w,
+ (void *)vl, &ld, (void *)vr, &ld,
(void *)&work_size_query, &do_size_query,
(void *)rwork,
&rv);
@@ -2181,9 +2207,9 @@ init_@lapack_func@(GEEV_PARAMS_t* params,
params->WORK = work;
params->W = w;
params->N = n;
- params->LDA = n;
- params->LDVL = n;
- params->LDVR = n;
+ params->LDA = ld;
+ params->LDVL = ld;
+ params->LDVR = ld;
params->LWORK = (fortran_int)work_count;
params->JOBVL = jobvl;
params->JOBVR = jobvr;
@@ -2421,7 +2447,7 @@ compute_urows_vtcolumns(char jobz,
fortran_int m, fortran_int n,
fortran_int *urows, fortran_int *vtcolumns)
{
- fortran_int min_m_n = m<n?m:n;
+ fortran_int min_m_n = fortran_int_min(m, n);
switch(jobz)
{
case 'N':
@@ -2464,7 +2490,7 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
size_t safe_m = m;
size_t safe_n = n;
size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
- fortran_int min_m_n = m<n?m:n;
+ fortran_int min_m_n = fortran_int_min(m, n);
size_t safe_min_m_n = min_m_n;
size_t s_size = safe_min_m_n * sizeof(@ftyp@);
fortran_int u_row_count, vt_column_count;
@@ -2473,6 +2499,7 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
fortran_int work_count;
size_t work_size;
size_t iwork_size = 8 * safe_min_m_n * sizeof(fortran_int);
+ fortran_int ld = fortran_int_max(m, 1);
if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count))
goto error;
@@ -2502,13 +2529,15 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
fortran_int do_query = -1;
fortran_int rv;
LAPACK(@lapack_func@)(&jobz, &m, &n,
- (void*)a, &m, (void*)s, (void*)u, &m,
+ (void*)a, &ld, (void*)s, (void*)u, &ld,
(void*)vt, &vt_column_count,
&work_size_query, &do_query,
(void*)iwork, &rv);
if (0!=rv)
goto error;
work_count = (fortran_int)work_size_query;
+ /* Fix a bug in lapack 3.0.0 */
+ if(work_count == 0) work_count = 1;
work_size = (size_t)work_count * sizeof(@ftyp@);
}
@@ -2529,8 +2558,8 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
params->IWORK = iwork;
params->M = m;
params->N = n;
- params->LDA = m;
- params->LDU = m;
+ params->LDA = ld;
+ params->LDU = ld;
params->LDVT = vt_column_count;
params->LWORK = work_count;
params->JOBZ = jobz;
@@ -2583,8 +2612,9 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
fortran_int u_row_count, vt_column_count, work_count;
size_t safe_m = m;
size_t safe_n = n;
- fortran_int min_m_n = m<n?m:n;
+ fortran_int min_m_n = fortran_int_min(m, n);
size_t safe_min_m_n = min_m_n;
+ fortran_int ld = fortran_int_max(m, 1);
if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count))
goto error;
@@ -2626,7 +2656,7 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
fortran_int do_query = -1;
fortran_int rv;
LAPACK(@lapack_func@)(&jobz, &m, &n,
- (void*)a, &m, (void*)s, (void*)u, &m,
+ (void*)a, &ld, (void*)s, (void*)u, &ld,
(void*)vt, &vt_column_count,
&work_size_query, &do_query,
(void*)rwork,
@@ -2654,8 +2684,8 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
params->IWORK = iwork;
params->M = m;
params->N = n;
- params->LDA = m;
- params->LDU = m;
+ params->LDA = ld;
+ params->LDU = ld;
params->LDVT = vt_column_count;
params->LWORK = work_count;
params->JOBZ = jobz;