diff options
Diffstat (limited to 'numpy/linalg/umath_linalg.c.src')
-rw-r--r-- | numpy/linalg/umath_linalg.c.src | 424 |
1 files changed, 416 insertions, 8 deletions
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src index 0248518ac..d8cfdf6ac 100644 --- a/numpy/linalg/umath_linalg.c.src +++ b/numpy/linalg/umath_linalg.c.src @@ -129,12 +129,26 @@ FNAME(zheevd)(char *jobz, char *uplo, int *n, int *info); extern int +FNAME(sgelsd)(int *m, int *n, int *nrhs, + float a[], int *lda, float b[], int *ldb, + float s[], float *rcond, int *rank, + float work[], int *lwork, int iwork[], + int *info); +extern int FNAME(dgelsd)(int *m, int *n, int *nrhs, double a[], int *lda, double b[], int *ldb, double s[], double *rcond, int *rank, double work[], int *lwork, int iwork[], int *info); extern int +FNAME(cgelsd)(int *m, int *n, int *nrhs, + f2c_complex a[], int *lda, + f2c_complex b[], int *ldb, + float s[], float *rcond, int *rank, + f2c_complex work[], int *lwork, + float rwork[], int iwork[], + int *info); +extern int FNAME(zgelsd)(int *m, int *n, int *nrhs, f2c_doublecomplex a[], int *lda, f2c_doublecomplex b[], int *ldb, @@ -492,6 +506,7 @@ static void init_constants(void) * columns: number of columns in the matrix * row_strides: the number bytes between consecutive rows. * column_strides: the number of bytes between consecutive columns. + * output_lead_dim: BLAS/LAPACK-side leading dimension, in elements */ typedef struct linearize_data_struct { @@ -499,19 +514,33 @@ typedef struct linearize_data_struct npy_intp columns; npy_intp row_strides; npy_intp column_strides; + npy_intp output_lead_dim; } LINEARIZE_DATA_t; static NPY_INLINE void +init_linearize_data_ex(LINEARIZE_DATA_t *lin_data, + npy_intp rows, + npy_intp columns, + npy_intp row_strides, + npy_intp column_strides, + npy_intp output_lead_dim) +{ + lin_data->rows = rows; + lin_data->columns = columns; + lin_data->row_strides = row_strides; + lin_data->column_strides = column_strides; + lin_data->output_lead_dim = output_lead_dim; +} + +static NPY_INLINE void init_linearize_data(LINEARIZE_DATA_t *lin_data, npy_intp rows, npy_intp columns, npy_intp row_strides, npy_intp column_strides) { - lin_data->rows = rows; - lin_data->columns = columns; - lin_data->row_strides = row_strides; - lin_data->column_strides = column_strides; + init_linearize_data_ex( + lin_data, rows, columns, row_strides, column_strides, columns); } static NPY_INLINE void @@ -846,7 +875,7 @@ linearize_@TYPE@_matrix(void *dst_in, } } src += data->row_strides/sizeof(@typ@); - dst += data->columns; + dst += data->output_lead_dim; } return rv; } else { @@ -893,7 +922,7 @@ delinearize_@TYPE@_matrix(void *dst_in, sizeof(@typ@)); } } - src += data->columns; + src += data->output_lead_dim; dst += data->row_strides/sizeof(@typ@); } @@ -2871,6 +2900,359 @@ static void /**end repeat**/ + +/* -------------------------------------------------------------------------- */ + /* least squares */ + +typedef struct gelsd_params_struct +{ + fortran_int M; + fortran_int N; + fortran_int NRHS; + void *A; + fortran_int LDA; + void *B; + fortran_int LDB; + void *S; + void *RCOND; + fortran_int RANK; + void *WORK; + fortran_int LWORK; + void *RWORK; + void *IWORK; +} GELSD_PARAMS_t; + + +static inline void +dump_gelsd_params(const char *name, + GELSD_PARAMS_t *params) +{ + TRACE_TXT("\n%s:\n"\ + + "%14s: %18p\n"\ + "%14s: %18p\n"\ + "%14s: %18p\n"\ + "%14s: %18p\n"\ + "%14s: %18p\n"\ + "%14s: %18p\n"\ + + "%14s: %18d\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n"\ + + "%14s: %18p\n", + + name, + + "A", params->A, + "B", params->B, + "S", params->S, + "WORK", params->WORK, + "RWORK", params->RWORK, + "IWORK", params->IWORK, + + "M", (int)params->M, + "N", (int)params->N, + "NRHS", (int)params->NRHS, + "LDA", (int)params->LDA, + "LDB", (int)params->LDB, + "LWORK", (int)params->LWORK, + "RANK", (int)params->RANK, + + "RCOND", params->RCOND); +} + + +/**begin repeat + #TYPE=FLOAT,DOUBLE# + #lapack_func=sgelsd,dgelsd# + #ftyp=fortran_real,fortran_doublereal# + */ + +static inline fortran_int +call_@lapack_func@(GELSD_PARAMS_t *params) +{ + fortran_int rv; + LAPACK(@lapack_func@)(¶ms->M, ¶ms->N, ¶ms->NRHS, + params->A, ¶ms->LDA, + params->B, ¶ms->LDB, + params->S, + params->RCOND, ¶ms->RANK, + params->WORK, ¶ms->LWORK, + params->IWORK, + &rv); + return rv; +} + +static inline int +init_@lapack_func@(GELSD_PARAMS_t *params, + fortran_int m, + fortran_int n, + fortran_int nrhs) +{ + npy_uint8 *mem_buff = NULL; + npy_uint8 *mem_buff2 = NULL; + npy_uint8 *a, *b, *s, *work, *iwork; + fortran_int min_m_n = fortran_int_min(m, n); + fortran_int max_m_n = fortran_int_max(m, n); + size_t safe_min_m_n = min_m_n; + size_t safe_max_m_n = max_m_n; + size_t safe_m = m; + size_t safe_n = n; + size_t safe_nrhs = nrhs; + + size_t a_size = safe_m * safe_n * sizeof(@ftyp@); + size_t b_size = safe_max_m_n * safe_nrhs * sizeof(@ftyp@); + size_t s_size = safe_min_m_n * sizeof(@ftyp@); + + fortran_int work_count; + size_t work_size; + size_t iwork_size; + fortran_int lda = fortran_int_max(1, m); + fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n)); + + mem_buff = malloc(a_size + b_size + s_size); + + if (!mem_buff) + goto error; + + a = mem_buff; + b = a + a_size; + s = b + b_size; + + + params->M = m; + params->N = n; + params->NRHS = nrhs; + params->A = a; + params->B = b; + params->S = s; + params->LDA = lda; + params->LDB = ldb; + + { + /* compute optimal work size */ + @ftyp@ work_size_query; + fortran_int iwork_size_query; + + params->WORK = &work_size_query; + params->IWORK = &iwork_size_query; + params->RWORK = NULL; + params->LWORK = -1; + + if (call_@lapack_func@(params) != 0) + goto error; + + work_count = (fortran_int)work_size_query; + + work_size = (size_t) work_size_query * sizeof(@ftyp@); + iwork_size = (size_t)iwork_size_query * sizeof(fortran_int); + } + + mem_buff2 = malloc(work_size + iwork_size); + if (!mem_buff2) + goto error; + + work = mem_buff2; + iwork = work + work_size; + + params->WORK = work; + params->RWORK = NULL; + params->IWORK = iwork; + params->LWORK = work_count; + + return 1; + error: + TRACE_TXT("%s failed init\n", __FUNCTION__); + free(mem_buff); + free(mem_buff2); + memset(params, 0, sizeof(*params)); + + return 0; +} + +/**end repeat**/ + +/**begin repeat + #TYPE=CFLOAT,CDOUBLE# + #ftyp=fortran_complex,fortran_doublecomplex# + #frealtyp=fortran_real,fortran_doublereal# + #typ=COMPLEX_t,DOUBLECOMPLEX_t# + #lapack_func=cgelsd,zgelsd# + */ + +static inline fortran_int +call_@lapack_func@(GELSD_PARAMS_t *params) +{ + fortran_int rv; + LAPACK(@lapack_func@)(¶ms->M, ¶ms->N, ¶ms->NRHS, + params->A, ¶ms->LDA, + params->B, ¶ms->LDB, + params->S, + params->RCOND, ¶ms->RANK, + params->WORK, ¶ms->LWORK, + params->RWORK, params->IWORK, + &rv); + return rv; +} + +static inline int +init_@lapack_func@(GELSD_PARAMS_t *params, + fortran_int m, + fortran_int n, + fortran_int nrhs) +{ + npy_uint8 *mem_buff = NULL; + npy_uint8 *mem_buff2 = NULL; + npy_uint8 *a, *b, *s, *work, *iwork, *rwork; + fortran_int min_m_n = fortran_int_min(m, n); + fortran_int max_m_n = fortran_int_max(m, n); + size_t safe_min_m_n = min_m_n; + size_t safe_max_m_n = max_m_n; + size_t safe_m = m; + size_t safe_n = n; + size_t safe_nrhs = nrhs; + + size_t a_size = safe_m * safe_n * sizeof(@ftyp@); + size_t b_size = safe_max_m_n * safe_nrhs * sizeof(@ftyp@); + size_t s_size = safe_min_m_n * sizeof(@frealtyp@); + + fortran_int work_count; + size_t work_size, rwork_size, iwork_size; + fortran_int lda = fortran_int_max(1, m); + fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n)); + + mem_buff = malloc(a_size + b_size + s_size); + + if (!mem_buff) + goto error; + + a = mem_buff; + b = a + a_size; + s = b + b_size; + + + params->M = m; + params->N = n; + params->NRHS = nrhs; + params->A = a; + params->B = b; + params->S = s; + params->LDA = lda; + params->LDB = ldb; + + { + /* compute optimal work size */ + @ftyp@ work_size_query; + @frealtyp@ rwork_size_query; + fortran_int iwork_size_query; + + params->WORK = &work_size_query; + params->IWORK = &iwork_size_query; + params->RWORK = &rwork_size_query; + params->LWORK = -1; + + if (call_@lapack_func@(params) != 0) + goto error; + + work_count = (fortran_int)work_size_query.r; + + work_size = (size_t )work_size_query.r * sizeof(@ftyp@); + rwork_size = (size_t)rwork_size_query * sizeof(@frealtyp@); + iwork_size = (size_t)iwork_size_query * sizeof(fortran_int); + } + + mem_buff2 = malloc(work_size + rwork_size + iwork_size); + if (!mem_buff2) + goto error; + + work = mem_buff2; + rwork = work + work_size; + iwork = rwork + rwork_size; + + params->WORK = work; + params->RWORK = rwork; + params->IWORK = iwork; + params->LWORK = work_count; + + return 1; + error: + TRACE_TXT("%s failed init\n", __FUNCTION__); + free(mem_buff); + free(mem_buff2); + memset(params, 0, sizeof(*params)); + + return 0; +} + +/**end repeat**/ + + +/**begin repeat + #TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE# + #REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE# + #lapack_func=sgelsd,dgelsd,cgelsd,zgelsd# + */ +static inline void +release_@lapack_func@(GELSD_PARAMS_t* params) +{ + /* A and WORK contain allocated blocks */ + free(params->A); + free(params->WORK); + memset(params, 0, sizeof(*params)); +} + +static void +@TYPE@_lstsq(char **args, npy_intp *dimensions, npy_intp *steps, + void *NPY_UNUSED(func)) +{ + GELSD_PARAMS_t params; + int error_occurred = get_fp_invalid_and_clear(); + fortran_int n, m, nrhs; + INIT_OUTER_LOOP_6 + + m = (fortran_int)dimensions[0]; + n = (fortran_int)dimensions[1]; + nrhs = (fortran_int)dimensions[2]; + + if (init_@lapack_func@(¶ms, m, n, nrhs)) { + LINEARIZE_DATA_t a_in, b_in, x_out, s_out; + + init_linearize_data(&a_in, n, m, steps[1], steps[0]); + init_linearize_data_ex(&b_in, nrhs, m, steps[3], steps[2], fortran_int_max(n, m)); + init_linearize_data(&x_out, nrhs, fortran_int_max(n, m), steps[5], steps[4]); + init_linearize_data(&s_out, 1, fortran_int_min(n, m), 1, steps[6]); + + BEGIN_OUTER_LOOP_6 + int not_ok; + linearize_@TYPE@_matrix(params.A, args[0], &a_in); + linearize_@TYPE@_matrix(params.B, args[1], &b_in); + params.RCOND = args[2]; + not_ok = call_@lapack_func@(¶ms); + if (!not_ok) { + delinearize_@TYPE@_matrix(args[3], params.B, &x_out); + *(npy_int*) args[4] = params.RANK; + delinearize_@REALTYPE@_matrix(args[5], params.S, &s_out); + } else { + error_occurred = 1; + nan_@TYPE@_matrix(args[3], &x_out); + *(npy_int*) args[4] = -1; + nan_@REALTYPE@_matrix(args[5], &s_out); + } + END_OUTER_LOOP + + release_@lapack_func@(¶ms); + } + + set_fp_invalid_or_clear(error_occurred); +} + +/**end repeat**/ + #pragma GCC diagnostic pop /* -------------------------------------------------------------------------- */ @@ -2941,6 +3323,7 @@ GUFUNC_FUNC_ARRAY_REAL_COMPLEX(cholesky_lo); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_N); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_S); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_A); +GUFUNC_FUNC_ARRAY_REAL_COMPLEX(lstsq); GUFUNC_FUNC_ARRAY_EIG(eig); GUFUNC_FUNC_ARRAY_EIG(eigvals); @@ -3006,6 +3389,14 @@ static char svd_1_3_types[] = { NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE }; +/* A, b, rcond, x, rank, s */ +static char lstsq_types[] = { + NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_INT, NPY_FLOAT, + NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE, + NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT, NPY_INT, NPY_FLOAT, + NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_INT, NPY_DOUBLE +}; + typedef struct gufunc_descriptor_struct { char *name; char *signature; @@ -3192,12 +3583,29 @@ GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = { "eigvals", "(m,m)->(m)", "eigvals on the last two dimension and broadcast to the rest. \n"\ - "Results in a vector of eigenvalues. \n"\ - " \"(m,m)->(m),(m,m)\" \n", + "Results in a vector of eigenvalues. \n", 3, 1, 1, FUNC_ARRAY_NAME(eigvals), eigvals_types }, + { + "lstsq_m", + "(m,n),(m,nrhs),()->(n,nrhs),(),(m)", + "least squares on the last two dimensions and broadcast to the rest. \n"\ + "For m <= n. \n", + 4, 3, 3, + FUNC_ARRAY_NAME(lstsq), + lstsq_types + }, + { + "lstsq_n", + "(m,n),(m,nrhs),()->(m,nrhs),(),(n)", + "least squares on the last two dimensions and broadcast to the rest. \n"\ + "For m >= n. \n", + 4, 3, 3, + FUNC_ARRAY_NAME(lstsq), + lstsq_types + } }; static void |