summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-11-06 23:24:32 -0800
committerEric Wieser <wieser.eric@gmail.com>2018-04-10 23:33:12 -0700
commitb5cdbd10180a9fcedc5f4751330064c6a03bc5e1 (patch)
tree404d012484eea7b7486436c41068ab81d2d0f067 /numpy
parentcf8e5c19b9451aab3a867e1ab364e51470e0b872 (diff)
downloadnumpy-b5cdbd10180a9fcedc5f4751330064c6a03bc5e1.tar.gz
ENH: Add wrapper functions to allocate workspaces
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/umath_linalg.c.src308
1 files changed, 308 insertions, 0 deletions
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src
index 07301f35c..77a205668 100644
--- a/numpy/linalg/umath_linalg.c.src
+++ b/numpy/linalg/umath_linalg.c.src
@@ -2885,6 +2885,314 @@ static void
/**end repeat**/
+
+/* -------------------------------------------------------------------------- */
+ /* least squares */
+
+typedef struct gelsd_params_struct
+{
+ fortran_int M;
+ fortran_int N;
+ fortran_int NRHS;
+ void *A;
+ fortran_int LDA;
+ void *B;
+ fortran_int LDB;
+ void *S;
+ void *RCOND;
+ fortran_int RANK;
+ void *WORK;
+ fortran_int LWORK;
+ void *RWORK;
+ void *IWORK;
+} GELSD_PARAMS_t;
+
+
+static inline void
+dump_gelsd_params(const char *name,
+ GELSD_PARAMS_t *params)
+{
+ TRACE_TXT("\n%s:\n"\
+
+ "%14s: %18p\n"\
+ "%14s: %18p\n"\
+ "%14s: %18p\n"\
+ "%14s: %18p\n"\
+ "%14s: %18p\n"\
+ "%14s: %18p\n"\
+
+ "%14s: %18d\n"\
+ "%14s: %18d\n"\
+ "%14s: %18d\n"\
+ "%14s: %18d\n"\
+ "%14s: %18d\n"\
+ "%14s: %18d\n"\
+ "%14s: %18d\n"\
+
+ "%14s: %18p\n",
+
+ name,
+
+ "A", params->A,
+ "B", params->B,
+ "S", params->S,
+ "WORK", params->WORK,
+ "RWORK", params->RWORK,
+ "IWORK", params->IWORK,
+
+ "M", (int)params->M,
+ "N", (int)params->N,
+ "NRHS", (int)params->NRHS,
+ "LDA", (int)params->LDA,
+ "LDB", (int)params->LDB,
+ "LWORK", (int)params->LWORK,
+ "RANK", (int)params->RANK,
+
+ "RCOND", params->RCOND);
+}
+
+
+/**begin repeat
+ #TYPE=FLOAT,DOUBLE#
+ #lapack_func=sgelsd,dgelsd#
+ #ftyp=fortran_real,fortran_doublereal#
+ */
+
+static inline fortran_int
+call_@lapack_func@(GELSD_PARAMS_t *params)
+{
+ fortran_int rv;
+ LAPACK(@lapack_func@)(&params->M, &params->N, &params->NRHS,
+ params->A, &params->LDA,
+ params->B, &params->LDB,
+ params->S,
+ params->RCOND, &params->RANK,
+ params->WORK, &params->LWORK,
+ params->IWORK,
+ &rv);
+ return rv;
+}
+
+static inline int
+init_@lapack_func@(GELSD_PARAMS_t *params,
+ fortran_int m,
+ fortran_int n,
+ fortran_int nrhs)
+{
+ npy_uint8 *mem_buff = NULL;
+ npy_uint8 *mem_buff2 = NULL;
+ npy_uint8 *a, *b, *s, *work, *iwork;
+ fortran_int min_m_n = fortran_int_min(m, n);
+ fortran_int max_m_n = fortran_int_max(m, n);
+ size_t safe_min_m_n = min_m_n;
+ size_t safe_max_m_n = max_m_n;
+ size_t safe_m = m;
+ size_t safe_n = n;
+ size_t safe_nrhs = nrhs;
+
+ size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
+ size_t b_size = safe_max_m_n * safe_nrhs * sizeof(@ftyp@);
+ size_t s_size = safe_min_m_n * sizeof(@ftyp@);
+
+ fortran_int work_count;
+ size_t work_size;
+ size_t iwork_size;
+ fortran_int lda = fortran_int_max(1, m);
+ fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n));
+
+ mem_buff = malloc(a_size + b_size + s_size);
+
+ if (!mem_buff)
+ goto error;
+
+ a = mem_buff;
+ b = a + a_size;
+ s = b + b_size;
+
+
+ params->M = m;
+ params->N = n;
+ params->NRHS = nrhs;
+ params->A = a;
+ params->B = b;
+ params->S = s;
+ params->LDA = lda;
+ params->LDB = ldb;
+
+ {
+ /* compute optimal work size */
+ @ftyp@ work_size_query;
+ fortran_int iwork_size_query;
+
+ params->WORK = &work_size_query;
+ params->IWORK = &iwork_size_query;
+ params->RWORK = NULL;
+ params->LWORK = -1;
+
+ if (call_@lapack_func@(params) != 0)
+ goto error;
+
+ work_count = (fortran_int)work_size_query;
+
+ work_size = (size_t) work_size_query * sizeof(@ftyp@);
+ iwork_size = (size_t)iwork_size_query * sizeof(fortran_int);
+ }
+
+ mem_buff2 = malloc(work_size + iwork_size);
+ if (!mem_buff2)
+ goto error;
+
+ work = mem_buff2;
+ iwork = work + work_size;
+
+ params->WORK = work;
+ params->RWORK = NULL;
+ params->IWORK = iwork;
+ params->LWORK = work_count;
+
+ return 1;
+ error:
+ TRACE_TXT("%s failed init\n", __FUNCTION__);
+ free(mem_buff);
+ free(mem_buff2);
+ memset(params, 0, sizeof(*params));
+
+ return 0;
+}
+
+/**end repeat**/
+
+/**begin repeat
+ #TYPE=CFLOAT,CDOUBLE#
+ #ftyp=fortran_complex,fortran_doublecomplex#
+ #frealtyp=fortran_real,fortran_doublereal#
+ #typ=COMPLEX_t,DOUBLECOMPLEX_t#
+ #lapack_func=cgelsd,zgelsd#
+ */
+
+static inline fortran_int
+call_@lapack_func@(GELSD_PARAMS_t *params)
+{
+ fortran_int rv;
+ LAPACK(@lapack_func@)(&params->M, &params->N, &params->NRHS,
+ params->A, &params->LDA,
+ params->B, &params->LDB,
+ params->S,
+ params->RCOND, &params->RANK,
+ params->WORK, &params->LWORK,
+ params->RWORK, params->IWORK,
+ &rv);
+ return rv;
+}
+
+static inline int
+init_@lapack_func@(GELSD_PARAMS_t *params,
+ fortran_int m,
+ fortran_int n,
+ fortran_int nrhs)
+{
+ npy_uint8 *mem_buff = NULL;
+ npy_uint8 *mem_buff2 = NULL;
+ npy_uint8 *a, *b, *s, *work, *iwork, *rwork;
+ fortran_int min_m_n = fortran_int_min(m, n);
+ fortran_int max_m_n = fortran_int_max(m, n);
+ size_t safe_min_m_n = min_m_n;
+ size_t safe_max_m_n = max_m_n;
+ size_t safe_m = m;
+ size_t safe_n = n;
+ size_t safe_nrhs = nrhs;
+
+ size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
+ size_t b_size = safe_max_m_n * safe_nrhs * sizeof(@ftyp@);
+ size_t s_size = safe_min_m_n * sizeof(@frealtyp@);
+
+ fortran_int work_count;
+ size_t work_size, rwork_size, iwork_size;
+ fortran_int lda = fortran_int_max(1, m);
+ fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n));
+
+ mem_buff = malloc(a_size + b_size + s_size);
+
+ if (!mem_buff)
+ goto error;
+
+ a = mem_buff;
+ b = a + a_size;
+ s = b + b_size;
+
+
+ params->M = m;
+ params->N = n;
+ params->NRHS = nrhs;
+ params->A = a;
+ params->B = b;
+ params->S = s;
+ params->LDA = lda;
+ params->LDB = ldb;
+
+ {
+ /* compute optimal work size */
+ @ftyp@ work_size_query;
+ @frealtyp@ rwork_size_query;
+ fortran_int iwork_size_query;
+
+ params->WORK = &work_size_query;
+ params->IWORK = &iwork_size_query;
+ params->RWORK = &rwork_size_query;
+ params->LWORK = -1;
+
+ if (call_@lapack_func@(params) != 0)
+ goto error;
+
+ work_count = (fortran_int)work_size_query.r;
+
+ work_size = (size_t )work_size_query.r * sizeof(@ftyp@);
+ rwork_size = (size_t)rwork_size_query * sizeof(@frealtyp@);
+ iwork_size = (size_t)iwork_size_query * sizeof(fortran_int);
+ }
+
+ mem_buff2 = malloc(work_size + rwork_size + iwork_size);
+ if (!mem_buff2)
+ goto error;
+
+ work = mem_buff2;
+ rwork = work + work_size;
+ iwork = rwork + rwork_size;
+
+ params->WORK = work;
+ params->RWORK = rwork;
+ params->IWORK = iwork;
+ params->LWORK = work_count;
+
+ return 1;
+ error:
+ TRACE_TXT("%s failed init\n", __FUNCTION__);
+ free(mem_buff);
+ free(mem_buff2);
+ memset(params, 0, sizeof(*params));
+
+ return 0;
+}
+
+/**end repeat**/
+
+
+/**begin repeat
+ #TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE#
+ #REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE#
+ #lapack_func=sgelsd,dgelsd,cgelsd,zgelsd#
+ */
+static inline void
+release_@lapack_func@(GELSD_PARAMS_t* params)
+{
+ /* A and WORK contain allocated blocks */
+ free(params->A);
+ free(params->WORK);
+ memset(params, 0, sizeof(*params));
+}
+
+/**end repeat**/
+
#pragma GCC diagnostic pop
/* -------------------------------------------------------------------------- */