summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Eric Wieser <wieser.eric@gmail.com> 2017-11-06 23:25:33 -0800
committer Eric Wieser <wieser.eric@gmail.com> 2018-04-10 23:33:12 -0700
commit efc254ed833a5f7a68c2dbb69868f5bee693e8e9 (patch)
tree 58aea13875a2f7116c87110072a75595e5758cae
parent b5cdbd10180a9fcedc5f4751330064c6a03bc5e1 (diff)
download numpy-efc254ed833a5f7a68c2dbb69868f5bee693e8e9.tar.gz
ENH: Add raw ufuncs to interface to gelsd functions
-rw-r--r-- numpy/linalg/umath_linalg.c.src 102
1 files changed, 94 insertions, 8 deletions
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src
index 77a205668..d8cfdf6ac 100644
--- a/numpy/linalg/umath_linalg.c.src
+++ b/numpy/linalg/umath_linalg.c.src
@@ -506,6 +506,7 @@ static void init_constants(void)
* columns: number of columns in the matrix
* row_strides: the number of bytes between consecutive rows.
* column_strides: the number of bytes between consecutive columns.
+ * output_lead_dim: BLAS/LAPACK-side leading dimension, in elements
*/
typedef struct linearize_data_struct
{
@@ -513,19 +514,33 @@ typedef struct linearize_data_struct
npy_intp columns;
npy_intp row_strides;
npy_intp column_strides;
+ npy_intp output_lead_dim;
} LINEARIZE_DATA_t;
static NPY_INLINE void
+init_linearize_data_ex(LINEARIZE_DATA_t *lin_data,
+ npy_intp rows,
+ npy_intp columns,
+ npy_intp row_strides,
+ npy_intp column_strides,
+ npy_intp output_lead_dim)
+{ /* Fill every field of *lin_data from the arguments. Unlike
+   * init_linearize_data(), the caller chooses output_lead_dim (the
+   * BLAS/LAPACK-side leading dimension, in elements) explicitly, so it
+   * may be larger than `columns`. Strides are in bytes. */
+ lin_data->rows = rows;
+ lin_data->columns = columns;
+ lin_data->row_strides = row_strides;
+ lin_data->column_strides = column_strides;
+ lin_data->output_lead_dim = output_lead_dim; /* step used on the linear-buffer side between rows */
+}
+
+static NPY_INLINE void
init_linearize_data(LINEARIZE_DATA_t *lin_data,
npy_intp rows,
npy_intp columns,
npy_intp row_strides,
npy_intp column_strides)
{
- lin_data->rows = rows;
- lin_data->columns = columns;
- lin_data->row_strides = row_strides;
- lin_data->column_strides = column_strides;
+ init_linearize_data_ex( /* delegate so all fields are set in one place */
+ lin_data, rows, columns, row_strides, column_strides, columns); /* output_lead_dim defaults to `columns`: rows are packed back to back */
}
static NPY_INLINE void
@@ -860,7 +875,7 @@ linearize_@TYPE@_matrix(void *dst_in,
}
}
src += data->row_strides/sizeof(@typ@);
- dst += data->columns;
+ dst += data->output_lead_dim;
}
return rv;
} else {
@@ -907,7 +922,7 @@ delinearize_@TYPE@_matrix(void *dst_in,
sizeof(@typ@));
}
}
- src += data->columns;
+ src += data->output_lead_dim;
dst += data->row_strides/sizeof(@typ@);
}
@@ -3191,6 +3206,51 @@ release_@lapack_func@(GELSD_PARAMS_t* params)
memset(params, 0, sizeof(*params));
}
+static void
+@TYPE@_lstsq(char **args, npy_intp *dimensions, npy_intp *steps,
+ void *NPY_UNUSED(func))
+{ /* gufunc loop for least squares.
+   * Inputs:  args[0] = A, args[1] = b, args[2] = rcond.
+   * Outputs: args[3] = x, args[4] = rank (npy_int), args[5] = s.
+   * A failed LAPACK call marks the outputs invalid (rank = -1) and
+   * passes 1 to set_fp_invalid_or_clear() at the end. */
+ GELSD_PARAMS_t params;
+ int error_occurred = get_fp_invalid_and_clear(); /* set to 1 below if any LAPACK call fails */
+ fortran_int n, m, nrhs;
+ INIT_OUTER_LOOP_6 /* presumably advances `dimensions`/`steps` past the outer-loop entries -- confirm against the macro */
+
+ m = (fortran_int)dimensions[0]; /* core dimensions, named as in the "(m,n),(m,nrhs)" signatures */
+ n = (fortran_int)dimensions[1];
+ nrhs = (fortran_int)dimensions[2];
+
+ if (init_@lapack_func@(&params, m, n, nrhs)) { /* NOTE(review): when init fails no output is written and error_occurred is not set -- confirm this is intended */
+ LINEARIZE_DATA_t a_in, b_in, x_out, s_out;
+
+ init_linearize_data(&a_in, n, m, steps[1], steps[0]); /* rows/columns swapped relative to (m,n); presumably yields the column-major layout LAPACK expects -- confirm */
+ init_linearize_data_ex(&b_in, nrhs, m, steps[3], steps[2], fortran_int_max(n, m)); /* m entries per right-hand side, stored max(n, m) elements apart in params.B */
+ init_linearize_data(&x_out, nrhs, fortran_int_max(n, m), steps[5], steps[4]); /* x is read back from params.B, max(n, m) entries per right-hand side */
+ init_linearize_data(&s_out, 1, fortran_int_min(n, m), 1, steps[6]); /* a single row of min(n, m) values (presumably the singular values) */
+
+ BEGIN_OUTER_LOOP_6
+ int not_ok;
+ linearize_@TYPE@_matrix(params.A, args[0], &a_in);
+ linearize_@TYPE@_matrix(params.B, args[1], &b_in);
+ params.RCOND = args[2]; /* stores the pointer to the rcond input, not a copy of its value */
+ not_ok = call_@lapack_func@(&params); /* nonzero is treated as failure */
+ if (!not_ok) {
+ delinearize_@TYPE@_matrix(args[3], params.B, &x_out); /* params.B is the source of x after the call */
+ *(npy_int*) args[4] = params.RANK;
+ delinearize_@REALTYPE@_matrix(args[5], params.S, &s_out); /* s uses the real type even in complex loops */
+ } else {
+ error_occurred = 1;
+ nan_@TYPE@_matrix(args[3], &x_out); /* mark outputs invalid via the nan_* helpers; rank gets -1 */
+ *(npy_int*) args[4] = -1;
+ nan_@REALTYPE@_matrix(args[5], &s_out);
+ }
+ END_OUTER_LOOP
+
+ release_@lapack_func@(&params); /* only reached when init succeeded */
+ }
+
+ set_fp_invalid_or_clear(error_occurred);
+}
+
/**end repeat**/
#pragma GCC diagnostic pop
@@ -3263,6 +3323,7 @@ GUFUNC_FUNC_ARRAY_REAL_COMPLEX(cholesky_lo);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_N);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_S);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_A);
+GUFUNC_FUNC_ARRAY_REAL_COMPLEX(lstsq);
GUFUNC_FUNC_ARRAY_EIG(eig);
GUFUNC_FUNC_ARRAY_EIG(eigvals);
@@ -3328,6 +3389,14 @@ static char svd_1_3_types[] = {
NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE
};
+/* Type codes for the lstsq loops, one row per loop, in argument order: A, b, rcond, x, rank, s */
+static char lstsq_types[] = {
+ NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_INT, NPY_FLOAT, /* float loop */
+ NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE, /* double loop */
+ NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT, NPY_INT, NPY_FLOAT, /* cfloat loop: rcond and s stay real */
+ NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_INT, NPY_DOUBLE /* cdouble loop: rcond and s stay real */
+};
+
typedef struct gufunc_descriptor_struct {
char *name;
char *signature;
@@ -3514,12 +3583,29 @@ GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = {
"eigvals",
"(m,m)->(m)",
"eigvals on the last two dimension and broadcast to the rest. \n"\
- "Results in a vector of eigenvalues. \n"\
- " \"(m,m)->(m),(m,m)\" \n",
+ "Results in a vector of eigenvalues. \n",
3, 1, 1,
FUNC_ARRAY_NAME(eigvals),
eigvals_types
},
+ {
+ "lstsq_m",
+ "(m,n),(m,nrhs),()->(n,nrhs),(),(m)",
+ "least squares on the last two dimensions and broadcast to the rest. \n"\
+ "For m <= n. \n",
+ 4, 3, 3,
+ FUNC_ARRAY_NAME(lstsq),
+ lstsq_types
+ },
+ {
+ "lstsq_n",
+ "(m,n),(m,nrhs),()->(m,nrhs),(),(n)",
+ "least squares on the last two dimensions and broadcast to the rest. \n"\
+ "For m >= n. \n",
+ 4, 3, 3,
+ FUNC_ARRAY_NAME(lstsq),
+ lstsq_types
+ }
};
static void