summaryrefslogtreecommitdiff
path: root/numpy/linalg/umath_linalg.c.src
diff options
context:
space:
mode:
authorczgdp1807 <gdp.1807@gmail.com>2021-06-05 09:00:45 +0530
committerczgdp1807 <gdp.1807@gmail.com>2021-06-05 10:55:37 +0530
commitc33cf14e144f8f3f8c226e81bbf693cdba7b3187 (patch)
tree215027962e21e793d85993d02eb66f16711103b1 /numpy/linalg/umath_linalg.c.src
parent885ba91db5132d3a4b1dfd5f60234bfedb08b432 (diff)
downloadnumpy-c33cf14e144f8f3f8c226e81bbf693cdba7b3187.tar.gz
existing tests passed
Diffstat (limited to 'numpy/linalg/umath_linalg.c.src')
-rw-r--r--numpy/linalg/umath_linalg.c.src62
1 files changed, 41 insertions, 21 deletions
diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src
index a2f463d24..1f794a1e3 100644
--- a/numpy/linalg/umath_linalg.c.src
+++ b/numpy/linalg/umath_linalg.c.src
@@ -3372,6 +3372,7 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,
{
/* compute optimal work size */
/* IS THE FOLLOWING METHOD CORRECT? */
+
@ftyp@ work_size_query;
params->WORK = &work_size_query;
@@ -3380,11 +3381,14 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,
if (call_@lapack_func@(params) != 0)
goto error;
- work_count = (fortran_int)work_size_query;
+ work_count = (fortran_int) *(@ftyp@*) params->WORK;
- work_size = (size_t) work_size_query * sizeof(@ftyp@);
}
+ params->LWORK = fortran_int_max(fortran_int_max(1, n),
+ work_count > 0 ? work_count : -work_count);
+
+ work_size = (size_t) params->LWORK * sizeof(@ftyp@);
mem_buff2 = malloc(work_size);
if (!mem_buff2)
goto error;
@@ -3393,7 +3397,6 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,
memset(work, 0, work_size);
params->WORK = work;
- params->LWORK = work_count;
return 1;
error:
@@ -3451,6 +3454,7 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,
{
/* compute optimal work size */
/* IS THE FOLLOWING METHOD CORRECT? */
+
@ftyp@ work_size_query;
params->WORK = &work_size_query;
@@ -3459,11 +3463,15 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,
if (call_@lapack_func@(params) != 0)
goto error;
- work_count = (fortran_int)work_size_query.r;
+ work_count = (fortran_int) ((@ftyp@*)params->WORK)->r;
- work_size = (size_t) work_size_query.r * sizeof(@ftyp@);
}
+ params->LWORK = fortran_int_max(fortran_int_max(1, n),
+ work_count > 0 ? work_count : -work_count);
+
+ work_size = (size_t) params->LWORK * sizeof(@ftyp@);
+
mem_buff2 = malloc(work_size);
if (!mem_buff2)
goto error;
@@ -3472,7 +3480,6 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,
memset(work, 0, work_size);
params->WORK = work;
- params->LWORK = work_count;
return 1;
error:
@@ -3631,7 +3638,6 @@ init_@lapack_func@(GQR_PARAMS_t *params,
size_t safe_min_m_n = min_m_n;
size_t safe_m = m;
size_t safe_n = n;
-
size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@);
size_t tau_size = safe_min_m_n * sizeof(@ftyp@);
@@ -3669,11 +3675,15 @@ init_@lapack_func@(GQR_PARAMS_t *params,
if (call_@lapack_func@(params) != 0)
goto error;
- work_count = (fortran_int)work_size_query;
+ work_count = (fortran_int) *(@ftyp@*) params->WORK;
- work_size = (size_t) work_size_query * sizeof(@ftyp@);
}
+ params->LWORK = fortran_int_max(fortran_int_max(1, n),
+ work_count > 0 ? work_count : -work_count);
+
+ work_size = (size_t) params->LWORK * sizeof(@ftyp@);
+
mem_buff2 = malloc(work_size);
if (!mem_buff2)
goto error;
@@ -3682,7 +3692,6 @@ init_@lapack_func@(GQR_PARAMS_t *params,
memset(work, 0, work_size);
params->WORK = work;
- params->LWORK = work_count;
return 1;
error:
@@ -3751,11 +3760,15 @@ init_@lapack_func@(GQR_PARAMS_t *params,
if (call_@lapack_func@(params) != 0)
goto error;
- work_count = (fortran_int)work_size_query.r;
+ work_count = (fortran_int) ((@ftyp@*)params->WORK)->r;
- work_size = (size_t) work_size_query.r * sizeof(@ftyp@);
}
+ params->LWORK = fortran_int_max(fortran_int_max(1, n),
+ work_count > 0 ? work_count : -work_count);
+
+ work_size = (size_t) params->LWORK * sizeof(@ftyp@);
+
mem_buff2 = malloc(work_size);
if (!mem_buff2)
goto error;
@@ -3862,6 +3875,7 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params,
fortran_int m,
fortran_int n)
{
+ printf("Inside zungqr\n");
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
npy_uint8 *a, *q, *tau, *work;
@@ -3871,7 +3885,7 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params,
size_t safe_n = n;
size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
- size_t q_size = safe_m * safe_min_m_n * sizeof(@ftyp@);
+ size_t q_size = safe_m * safe_m * sizeof(@ftyp@);
size_t tau_size = safe_min_m_n * sizeof(@ftyp@);
fortran_int work_count;
@@ -3889,7 +3903,7 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params,
params->M = m;
- params->MC = min_m_n;
+ params->MC = m;
params->MN = min_m_n;
params->A = a;
params->Q = q;
@@ -3907,11 +3921,15 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params,
if (call_@lapack_func@(params) != 0)
goto error;
- work_count = (fortran_int)work_size_query;
+ work_count = (fortran_int) *(@ftyp@*) params->WORK;
- work_size = (size_t) work_size_query * sizeof(@ftyp@);
}
+ params->LWORK = fortran_int_max(fortran_int_max(1, n),
+ work_count > 0 ? work_count : -work_count);
+
+ work_size = (size_t) params->LWORK * sizeof(@ftyp@);
+
mem_buff2 = malloc(work_size);
if (!mem_buff2)
goto error;
@@ -3920,7 +3938,6 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params,
memset(work, 0, work_size);
params->WORK = work;
- params->LWORK = work_count;
return 1;
error:
@@ -3941,8 +3958,8 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params,
*/
static inline int
init_@lapack_func@_complete(GQR_PARAMS_t *params,
- fortran_int m,
- fortran_int n)
+ fortran_int m,
+ fortran_int n)
{
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
@@ -3989,11 +4006,14 @@ init_@lapack_func@_complete(GQR_PARAMS_t *params,
if (call_@lapack_func@(params) != 0)
goto error;
- work_count = (fortran_int)work_size_query.r;
+ work_count = (fortran_int) ((@ftyp@*)params->WORK)->r;
- work_size = (size_t) work_size_query.r * sizeof(@ftyp@);
}
+ params->LWORK = fortran_int_max(fortran_int_max(1, n),
+ work_count > 0 ? work_count : -work_count);
+
+ work_size = (size_t) params->LWORK * sizeof(@ftyp@);
mem_buff2 = malloc(work_size);
if (!mem_buff2)
goto error;