summaryrefslogtreecommitdiff
path: root/numpy/linalg/umath_linalg.c.src
diff options
context:
space:
mode:
authorczgdp1807 <gdp.1807@gmail.com>2021-06-04 11:59:36 +0530
committerczgdp1807 <gdp.1807@gmail.com>2021-06-04 11:59:36 +0530
commitd5cbb5e0eefec9c97a901f1efa3e1a40d1ec0b04 (patch)
treee4eb977c76457bec38cd0a7b351d66098c5bd371 /numpy/linalg/umath_linalg.c.src
parentc25c09b0688bcda1da148fbcac47adab697c409c (diff)
downloadnumpy-d5cbb5e0eefec9c97a901f1efa3e1a40d1ec0b04.tar.gz
WIP
Diffstat (limited to 'numpy/linalg/umath_linalg.c.src')
-rw-r--r--numpy/linalg/umath_linalg.c.src365
1 files changed, 349 insertions, 16 deletions
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src
index a2cc86dc4..045b70003 100644
--- a/numpy/linalg/umath_linalg.c.src
+++ b/numpy/linalg/umath_linalg.c.src
@@ -179,6 +179,23 @@ FNAME(zgeqrf)(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int
f2c_doublecomplex tau[], f2c_doublecomplex work[],
fortran_int *lwork, fortran_int *info);
+extern fortran_int
+FNAME(sorgqr)(fortran_int *m, fortran_int *n, fortran_int *k, float a[], fortran_int *lda,
+ float tau[], float work[],
+ fortran_int *lwork, fortran_int *info);
+extern fortran_int
+FNAME(dorgqr)(fortran_int *m, fortran_int *n, fortran_int *k, double a[], fortran_int *lda,
+ double tau[], double work[],
+ fortran_int *lwork, fortran_int *info);
+extern fortran_int
+FNAME(cungqr)(fortran_int *m, fortran_int *n, fortran_int *k, f2c_complex a[],
+ fortran_int *lda, f2c_complex tau[],
+ f2c_complex work[], fortran_int *lwork, fortran_int *info);
+extern fortran_int
+FNAME(zungqr)(fortran_int *m, fortran_int *n, fortran_int *k, f2c_doublecomplex a[],
+ fortran_int *lda, f2c_doublecomplex tau[],
+ f2c_doublecomplex work[], fortran_int *lwork, fortran_int *info);
+
extern fortran_int
FNAME(sgesv)(fortran_int *n, fortran_int *nrhs,
float a[], fortran_int *lda,
@@ -3253,7 +3270,7 @@ static void
/**end repeat**/
/* -------------------------------------------------------------------------- */
- /* qr (modes - r, raw, economic) */
+ /* qr (modes - r, raw) */
typedef struct geqfr_params_struct
{
@@ -3278,7 +3295,7 @@ dump_geqrf_params(const char *name,
"%14s: %18p\n"\
"%14s: %18d\n"\
"%14s: %18d\n"\
- "%14s: %18d\n"
+ "%14s: %18d\n"\
"%14s: %18d\n",
name,
@@ -3421,7 +3438,7 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,
goto error;
a = mem_buff;
- tau = a + tau_size;
+ tau = a + a_size;
memset(tau, 0, tau_size);
@@ -3452,6 +3469,7 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,
goto error;
work = mem_buff2;
+ memset(work, 0, work_size);
params->WORK = work;
params->LWORK = work_count;
@@ -3486,7 +3504,6 @@ release_@lapack_func@(GEQRF_PARAMS_t* params)
#TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE#
#REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE#
#lapack_func=sgeqrf,dgeqrf,cgeqrf,zgeqrf#
- #dot_func=sdot,ddot,cdotc,zdotc#
#typ = npy_float, npy_double, npy_cfloat, npy_cdouble#
#basetyp = npy_float, npy_double, npy_float, npy_double#
#ftyp = fortran_real, fortran_doublereal,
@@ -3494,7 +3511,7 @@ release_@lapack_func@(GEQRF_PARAMS_t* params)
#cmplx = 0, 0, 1, 1#
*/
static void
-@TYPE@_qr_r_raw_e(char **args, npy_intp const *dimensions, npy_intp const *steps,
+@TYPE@_qr_r_raw(char **args, npy_intp const *dimensions, npy_intp const *steps,
void *NPY_UNUSED(func))
{
GEQRF_PARAMS_t params;
@@ -3507,10 +3524,9 @@ static void
n = (fortran_int)dimensions[1];
if (init_@lapack_func@(&params, m, n)) {
- LINEARIZE_DATA_t a_in, a_out, tau_out;
+ LINEARIZE_DATA_t a_in, tau_out;
init_linearize_data(&a_in, n, m, steps[1], steps[0]);
- init_linearize_data(&a_out, m, n, steps[0], steps[1]);
init_linearize_data(&tau_out, 1, fortran_int_min(m, n), 1, steps[2]);
BEGIN_OUTER_LOOP_2
@@ -3518,7 +3534,7 @@ static void
linearize_@TYPE@_matrix(params.A, args[0], &a_in);
not_ok = call_@lapack_func@(&params);
if (!not_ok) {
- delinearize_@TYPE@_matrix(args[0], params.A, &a_out);
+ delinearize_@TYPE@_matrix(args[0], params.A, &a_in);
delinearize_@TYPE@_matrix(args[1], params.TAU, &tau_out);
} else {
error_occurred = 1;
@@ -3535,6 +3551,297 @@ static void
/**end repeat**/
+/* -------------------------------------------------------------------------- */
+ /* qr (modes - reduced) */
+
+typedef struct gqr_params_struct
+{
+ fortran_int M;
+ fortran_int MC;
+ fortran_int MN;
+ void *Q;
+ fortran_int LDA;
+ void* TAU;
+ void *WORK;
+ fortran_int LWORK;
+} GQR_PARAMS_t;
+
+
+static inline void
+dump_gqr_params(const char *name,
+ GQR_PARAMS_t *params)
+{
+ TRACE_TXT("\n%s:\n"\
+
+ "%14s: %18p\n"\
+ "%14s: %18p\n"\
+ "%14s: %18p\n"\
+ "%14s: %18d\n"\
+ "%14s: %18d\n"\
+ "%14s: %18d\n"\
+ "%14s: %18d\n"\
+ "%14s: %18d\n",
+
+ name,
+
+ "Q", params->Q,
+ "TAU", params->TAU,
+ "WORK", params->WORK,
+
+ "M", (int)params->M,
+ "MC", (int)params->MC,
+ "MN", (int)params->MN,
+ "LDA", (int)params->LDA,
+ "LWORK", (int)params->LWORK);
+}
+
+/**begin repeat
+ #lapack_func=sorgqr,dorgqr,cungqr,zungqr#
+ */
+
+static inline fortran_int
+call_@lapack_func@(GQR_PARAMS_t *params)
+{
+ fortran_int rv;
+ LAPACK(@lapack_func@)(&params->M, &params->MC, &params->MN,
+ params->Q, &params->LDA,
+ params->TAU,
+ params->WORK, &params->LWORK,
+ &rv);
+ return rv;
+}
+
+/**end repeat**/
+
+/**begin repeat
+ #TYPE=FLOAT,DOUBLE#
+ #lapack_func=sorgqr,dorgqr#
+ #ftyp=fortran_real,fortran_doublereal#
+ */
+static inline int
+init_@lapack_func@(GQR_PARAMS_t *params,
+ void* a,
+ fortran_int m,
+ fortran_int n)
+{
+ printf("%zu %zu\n", m, n);
+ npy_uint8 *mem_buff = NULL;
+ npy_uint8 *mem_buff2 = NULL;
+ npy_uint8 *q, *tau, *work;
+ fortran_int min_m_n = fortran_int_min(m, n);
+ size_t safe_min_m_n = min_m_n;
+ size_t safe_m = m;
+ size_t safe_n = n;
+
+ size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@);
+ size_t tau_size = safe_min_m_n * sizeof(@ftyp@);
+
+ fortran_int work_count;
+ size_t work_size;
+ fortran_int lda = fortran_int_max(1, m);
+
+ mem_buff = malloc(q_size + tau_size);
+
+ if (!mem_buff)
+ goto error;
+
+ q = mem_buff;
+ memcpy(q, a, q_size);
+ tau = q + q_size;
+
+
+ params->M = m;
+ params->MC = min_m_n;
+ params->MN = min_m_n;
+ params->Q = q;
+ params->TAU = tau;
+ params->LDA = lda;
+
+ {
+ /* compute optimal work size */
+ /* IS THE FOLLOWING METHOD CORRECT? */
+ @ftyp@ work_size_query;
+
+ params->WORK = &work_size_query;
+ params->LWORK = -1;
+
+ if (call_@lapack_func@(params) != 0)
+ goto error;
+
+ work_count = (fortran_int)work_size_query;
+
+ work_size = (size_t) work_size_query * sizeof(@ftyp@);
+ }
+
+ mem_buff2 = malloc(work_size);
+ if (!mem_buff2)
+ goto error;
+
+ work = mem_buff2;
+ memset(work, 0, work_size);
+
+ params->WORK = work;
+ params->LWORK = work_count;
+
+ return 1;
+ error:
+ TRACE_TXT("%s failed init\n", __FUNCTION__);
+ free(mem_buff);
+ free(mem_buff2);
+ memset(params, 0, sizeof(*params));
+
+ return 0;
+}
+
+/**end repeat**/
+
+/**begin repeat
+ #TYPE=CFLOAT,CDOUBLE#
+ #lapack_func=cungqr,zungqr#
+ #ftyp=fortran_complex,fortran_doublecomplex#
+ */
+static inline int
+init_@lapack_func@(GQR_PARAMS_t *params,
+ void* a,
+ fortran_int m,
+ fortran_int n)
+{
+ npy_uint8 *mem_buff = NULL;
+ npy_uint8 *mem_buff2 = NULL;
+ npy_uint8 *q, *tau, *work;
+ fortran_int min_m_n = fortran_int_min(m, n);
+ size_t safe_min_m_n = min_m_n;
+ size_t safe_m = m;
+ size_t safe_n = n;
+
+ size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@);
+ size_t tau_size = safe_min_m_n * sizeof(@ftyp@);
+
+ fortran_int work_count;
+ size_t work_size;
+ fortran_int lda = fortran_int_max(1, m);
+
+ mem_buff = malloc(q_size + tau_size);
+
+ if (!mem_buff)
+ goto error;
+
+ q = mem_buff;
+ memcpy(q, a, q_size);
+ tau = q + q_size;
+
+
+ params->M = m;
+ params->MC = min_m_n;
+ params->MN = min_m_n;
+ params->Q = q;
+ params->TAU = tau;
+ params->LDA = lda;
+
+ {
+ /* compute optimal work size */
+ /* IS THE FOLLOWING METHOD CORRECT? */
+ @ftyp@ work_size_query;
+
+ params->WORK = &work_size_query;
+ params->LWORK = -1;
+
+ if (call_@lapack_func@(params) != 0)
+ goto error;
+
+ work_count = (fortran_int)work_size_query.r;
+
+ work_size = (size_t) work_size_query.r * sizeof(@ftyp@);
+ }
+
+ mem_buff2 = malloc(work_size);
+ if (!mem_buff2)
+ goto error;
+
+ work = mem_buff2;
+ memset(work, 0, work_size);
+
+ params->WORK = work;
+ params->LWORK = work_count;
+
+ return 1;
+ error:
+ TRACE_TXT("%s failed init\n", __FUNCTION__);
+ free(mem_buff);
+ free(mem_buff2);
+ memset(params, 0, sizeof(*params));
+
+ return 0;
+}
+
+/**end repeat**/
+
+/**begin repeat
+ #lapack_func=sorgqr,dorgqr,cungqr,zungqr#
+ */
+static inline void
+release_@lapack_func@(GQR_PARAMS_t* params)
+{
+ /* A and WORK contain allocated blocks */
+ free(params->Q);
+ free(params->WORK);
+ memset(params, 0, sizeof(*params));
+}
+
+/**end repeat**/
+
+/**begin repeat
+ #TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE#
+ #REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE#
+ #lapack_func=sorgqr,dorgqr,cungqr,zungqr#
+ #typ = npy_float, npy_double, npy_cfloat, npy_cdouble#
+ #basetyp = npy_float, npy_double, npy_float, npy_double#
+ #ftyp = fortran_real, fortran_doublereal,
+ fortran_complex, fortran_doublecomplex#
+ #cmplx = 0, 0, 1, 1#
+ */
+static void
+@TYPE@_qr_reduced(char **args, npy_intp const *dimensions, npy_intp const *steps,
+ void *NPY_UNUSED(func))
+{
+ GQR_PARAMS_t params;
+ int error_occurred = get_fp_invalid_and_clear();
+ fortran_int n, m;
+
+ INIT_OUTER_LOOP_3
+
+ m = (fortran_int)dimensions[0];
+ n = (fortran_int)dimensions[1];
+
+
+ if (init_@lapack_func@(&params, (void*)args[0], m, n)) {
+ LINEARIZE_DATA_t tau_in, q_out;
+
+ init_linearize_data(&tau_in, 1, fortran_int_min(m, n), 1, steps[2]);
+ init_linearize_data(&q_out, fortran_int_min(m, n), m, steps[4], steps[3]);
+
+ BEGIN_OUTER_LOOP_3
+ int not_ok;
+ linearize_@TYPE@_matrix(params.TAU, args[1], &tau_in);
+ not_ok = call_@lapack_func@(&params);
+ if (!not_ok) {
+ delinearize_@TYPE@_matrix(args[1], params.TAU, &tau_in);
+ delinearize_@TYPE@_matrix(args[2], params.Q, &q_out);
+ } else {
+ error_occurred = 1;
+ nan_@TYPE@_matrix(args[1], &tau_in);
+ nan_@TYPE@_matrix(args[2], &q_out);
+ }
+ END_OUTER_LOOP
+
+ release_@lapack_func@(&params);
+ }
+
+ set_fp_invalid_or_clear(error_occurred);
+}
+
+/**end repeat**/
+
#pragma GCC diagnostic pop
/* -------------------------------------------------------------------------- */
@@ -3606,7 +3913,8 @@ GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_N);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_S);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_A);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(lstsq);
-GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_r_raw_e);
+GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_r_raw);
+GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr_reduced);
GUFUNC_FUNC_ARRAY_EIG(eig);
GUFUNC_FUNC_ARRAY_EIG(eigvals);
@@ -3680,13 +3988,20 @@ static char lstsq_types[] = {
NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE,
};
-static char qr_r_raw_e_types[] = {
+static char qr_r_raw_types[] = {
NPY_FLOAT, NPY_FLOAT,
NPY_DOUBLE, NPY_DOUBLE,
NPY_CFLOAT, NPY_CFLOAT,
NPY_CDOUBLE, NPY_CDOUBLE,
};
+static char qr_reduced_types[] = {
+ NPY_FLOAT, NPY_FLOAT, NPY_FLOAT,
+ NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,
+ NPY_CFLOAT, NPY_CFLOAT, NPY_CFLOAT,
+ NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE,
+};
+
typedef struct gufunc_descriptor_struct {
char *name;
char *signature;
@@ -3897,22 +4212,40 @@ GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = {
lstsq_types
},
{
- "qr_r_raw_e_m",
+ "qr_r_raw_m",
"(m,n)->(m)",
"Compute TAU vector for the last two dimensions \n"\
"and broadcast to the rest. For m <= n. \n",
4, 1, 1,
- FUNC_ARRAY_NAME(qr_r_raw_e),
- qr_r_raw_e_types
+ FUNC_ARRAY_NAME(qr_r_raw),
+ qr_r_raw_types
},
{
- "qr_r_raw_e_n",
+ "qr_r_raw_n",
"(m,n)->(n)",
"Compute TAU vector for the last two dimensions \n"\
"and broadcast to the rest. For m > n. \n",
4, 1, 1,
- FUNC_ARRAY_NAME(qr_r_raw_e),
- qr_r_raw_e_types
+ FUNC_ARRAY_NAME(qr_r_raw),
+ qr_r_raw_types
+ },
+ {
+ "qr_reduced_m",
+ "(m,n),(m)->(m,m)",
+ "Compute Q matrix for the last two dimensions \n"\
+ "and broadcast to the rest. For m <= n. \n",
+ 4, 2, 1,
+ FUNC_ARRAY_NAME(qr_reduced),
+ qr_reduced_types
+ },
+ {
+ "qr_reduced_n",
+ "(m,n),(n)->(m,n)",
+ "Compute Q matrix for the last two dimensions \n"\
+ "and broadcast to the rest. For m > n. \n",
+ 4, 2, 1,
+ FUNC_ARRAY_NAME(qr_reduced),
+ qr_reduced_types
}
};