diff options
Diffstat (limited to 'numpy/linalg/umath_linalg.cpp')
-rw-r--r-- | numpy/linalg/umath_linalg.cpp | 81 |
1 files changed, 57 insertions, 24 deletions
diff --git a/numpy/linalg/umath_linalg.cpp b/numpy/linalg/umath_linalg.cpp index bbb4bb896..68db2b2f1 100644 --- a/numpy/linalg/umath_linalg.cpp +++ b/numpy/linalg/umath_linalg.cpp @@ -22,6 +22,7 @@ #include <cstdio> #include <cassert> #include <cmath> +#include <type_traits> #include <utility> @@ -1148,7 +1149,7 @@ slogdet(char **args, void *NPY_UNUSED(func)) { fortran_int m; - npy_uint8 *tmp_buff = NULL; + char *tmp_buff = NULL; size_t matrix_size; size_t pivot_size; size_t safe_m; @@ -1162,10 +1163,11 @@ slogdet(char **args, */ INIT_OUTER_LOOP_3 m = (fortran_int) dimensions[0]; - safe_m = m; + /* avoid empty malloc (buffers likely unused) and ensure m is `size_t` */ + safe_m = m != 0 ? m : 1; matrix_size = safe_m * safe_m * sizeof(typ); pivot_size = safe_m * sizeof(fortran_int); - tmp_buff = (npy_uint8 *)malloc(matrix_size + pivot_size); + tmp_buff = (char *)malloc(matrix_size + pivot_size); if (tmp_buff) { LINEARIZE_DATA_t lin_data; @@ -1182,6 +1184,13 @@ slogdet(char **args, free(tmp_buff); } + else { + /* TODO: Requires use of new ufunc API to indicate error return */ + NPY_ALLOW_C_API_DEF + NPY_ALLOW_C_API; + PyErr_NoMemory(); + NPY_DISABLE_C_API; + } } template<typename typ, typename basetyp> @@ -1192,7 +1201,7 @@ det(char **args, void *NPY_UNUSED(func)) { fortran_int m; - npy_uint8 *tmp_buff; + char *tmp_buff; size_t matrix_size; size_t pivot_size; size_t safe_m; @@ -1206,10 +1215,11 @@ det(char **args, */ INIT_OUTER_LOOP_2 m = (fortran_int) dimensions[0]; - safe_m = m; + /* avoid empty malloc (buffers likely unused) and ensure m is `size_t` */ + safe_m = m != 0 ? m : 1; matrix_size = safe_m * safe_m * sizeof(typ); pivot_size = safe_m * sizeof(fortran_int); - tmp_buff = (npy_uint8 *)malloc(matrix_size + pivot_size); + tmp_buff = (char *)malloc(matrix_size + pivot_size); if (tmp_buff) { LINEARIZE_DATA_t lin_data; @@ -1230,6 +1240,13 @@ det(char **args, free(tmp_buff); } + else { + /* TODO: Requires use of new ufunc API to indicate error return */ + NPY_ALLOW_C_API_DEF + NPY_ALLOW_C_API; + PyErr_NoMemory(); + NPY_DISABLE_C_API; + } } @@ -3737,16 +3754,16 @@ scalar_trait) fortran_int lda = fortran_int_max(1, m); fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n)); - mem_buff = (npy_uint8 *)malloc(a_size + b_size + s_size); - - if (!mem_buff) - goto error; + size_t msize = a_size + b_size + s_size; + mem_buff = (npy_uint8 *)malloc(msize != 0 ? msize : 1); + if (!mem_buff) { + goto no_memory; + } a = mem_buff; b = a + a_size; s = b + b_size; - params->M = m; params->N = n; params->NRHS = nrhs; @@ -3766,9 +3783,9 @@ scalar_trait) params->RWORK = NULL; params->LWORK = -1; - if (call_gelsd(params) != 0) + if (call_gelsd(params) != 0) { goto error; - + } work_count = (fortran_int)work_size_query; work_size = (size_t) work_size_query * sizeof(ftyp); @@ -3776,9 +3793,9 @@ scalar_trait) } mem_buff2 = (npy_uint8 *)malloc(work_size + iwork_size); - if (!mem_buff2) - goto error; - + if (!mem_buff2) { + goto no_memory; + } work = mem_buff2; iwork = work + work_size; @@ -3788,12 +3805,18 @@ scalar_trait) params->LWORK = work_count; return 1; + + no_memory: + NPY_ALLOW_C_API_DEF + NPY_ALLOW_C_API; + PyErr_NoMemory(); + NPY_DISABLE_C_API; + error: TRACE_TXT("%s failed init\n", __FUNCTION__); free(mem_buff); free(mem_buff2); memset(params, 0, sizeof(*params)); - return 0; } @@ -3857,16 +3880,17 @@ using frealtyp = basetype_t<ftyp>; fortran_int lda = fortran_int_max(1, m); fortran_int ldb = fortran_int_max(1, fortran_int_max(m,n)); - mem_buff = (npy_uint8 *)malloc(a_size + b_size + s_size); + size_t msize = a_size + b_size + s_size; + mem_buff = (npy_uint8 *)malloc(msize != 0 ? msize : 1); - if (!mem_buff) - goto error; + if (!mem_buff) { + goto no_memory; + } a = mem_buff; b = a + a_size; s = b + b_size; - params->M = m; params->N = n; params->NRHS = nrhs; @@ -3887,8 +3911,9 @@ using frealtyp = basetype_t<ftyp>; params->RWORK = &rwork_size_query; params->LWORK = -1; - if (call_gelsd(params) != 0) + if (call_gelsd(params) != 0) { goto error; + } work_count = (fortran_int)work_size_query.r; @@ -3898,8 +3923,9 @@ using frealtyp = basetype_t<ftyp>; } mem_buff2 = (npy_uint8 *)malloc(work_size + rwork_size + iwork_size); - if (!mem_buff2) - goto error; + if (!mem_buff2) { + goto no_memory; + } work = mem_buff2; rwork = work + work_size; @@ -3911,6 +3937,13 @@ using frealtyp = basetype_t<ftyp>; params->LWORK = work_count; return 1; + + no_memory: + NPY_ALLOW_C_API_DEF + NPY_ALLOW_C_API; + PyErr_NoMemory(); + NPY_DISABLE_C_API; + error: TRACE_TXT("%s failed init\n", __FUNCTION__); free(mem_buff); |