diff options
author | mattip <matti.picus@gmail.com> | 2021-11-01 09:15:18 +0200 |
---|---|---|
committer | mattip <matti.picus@gmail.com> | 2021-11-02 11:35:33 +0200 |
commit | 0c992dca1cc23f12af14a8d66101166ef6c92355 (patch) | |
tree | 8e1cb673cfb1d71ca328d82a2c36d1ec5b929a99 | |
parent | f96aaa2ac484cab4e9f40b36902544e1174e0513 (diff) | |
download | numpy-0c992dca1cc23f12af14a8d66101166ef6c92355.tar.gz |
BUG: fixes from review
-rw-r--r-- | numpy/core/src/multiarray/dlpack.c | 104 |
1 files changed, 64 insertions, 40 deletions
diff --git a/numpy/core/src/multiarray/dlpack.c b/numpy/core/src/multiarray/dlpack.c index 591eddfaf..f1591bb1f 100644 --- a/numpy/core/src/multiarray/dlpack.c +++ b/numpy/core/src/multiarray/dlpack.c @@ -22,7 +22,7 @@ array_dlpack_deleter(DLManagedTensor *self) /* This is exactly as mandated by dlpack */ static void dlpack_capsule_deleter(PyObject *self) { - if (PyCapsule_IsValid(self, "used_dltensor")) { + if (PyCapsule_IsValid(self, NPY_DLPACK_USED_CAPSULE_NAME)) { return; } @@ -30,12 +30,16 @@ static void dlpack_capsule_deleter(PyObject *self) { PyObject *type, *value, *traceback; PyErr_Fetch(&type, &value, &traceback); - DLManagedTensor *managed = (DLManagedTensor *)PyCapsule_GetPointer(self, "dltensor"); + DLManagedTensor *managed = + (DLManagedTensor *)PyCapsule_GetPointer(self, NPY_DLPACK_CAPSULE_NAME); if (managed == NULL) { PyErr_WriteUnraisable(self); goto done; } - /* the spec says the deleter can be NULL if there is no way for the caller to provide a reasonable destructor. */ + /* + * the spec says the deleter can be NULL if there is no way for the caller + * to provide a reasonable destructor. + */ if (managed->deleter) { managed->deleter(managed); /* TODO: is the deleter allowed to set a python exception? */ @@ -46,6 +50,34 @@ done: PyErr_Restore(type, value, traceback); } +/* used internally, almost identical to dlpack_capsule_deleter() */ +static void array_dlpack_internal_capsule_deleter(PyObject *self) +{ + /* an exception may be in-flight, we must save it in case we create another one */ + PyObject *type, *value, *traceback; + PyErr_Fetch(&type, &value, &traceback); + + DLManagedTensor *managed = + (DLManagedTensor *)PyCapsule_GetPointer(self, NPY_DLPACK_INTERNAL_CAPSULE_NAME); + if (managed == NULL) { + PyErr_WriteUnraisable(self); + goto done; + } + /* + * the spec says the deleter can be NULL if there is no way for the caller + * to provide a reasonable destructor. + */ + if (managed->deleter) { + managed->deleter(managed); + /* TODO: is the deleter allowed to set a python exception? */ + assert(!PyErr_Occurred()); + } + +done: + PyErr_Restore(type, value, traceback); +} + + // This function cannot return NULL, but it can fail, // So call PyErr_Occurred to check if it failed after // calling it. @@ -82,17 +114,6 @@ array_get_dl_data(PyArrayObject *self) { return PyArray_DATA(self); } -/* used internally */ -static void array_dlpack_internal_capsule_deleter(PyObject *self) -{ - DLManagedTensor *managed = - (DLManagedTensor *)PyCapsule_GetPointer(self, NPY_DLPACK_INTERNAL_CAPSULE_NAME); - if (managed == NULL) { - return; - } - managed->deleter(managed); -} - PyObject * array_dlpack(PyArrayObject *self, PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames) @@ -140,9 +161,11 @@ array_dlpack(PyArrayObject *self, if (PyDataType_ISSIGNED(dtype)) { managed_dtype.code = kDLInt; - } else if (PyDataType_ISUNSIGNED(dtype)) { + } + else if (PyDataType_ISUNSIGNED(dtype)) { managed_dtype.code = kDLUInt; - } else if (PyDataType_ISFLOAT(dtype)) { + } + else if (PyDataType_ISFLOAT(dtype)) { // We can't be sure that the dtype is // IEEE or padded. if (itemsize > 8) { @@ -151,7 +174,8 @@ array_dlpack(PyArrayObject *self, return NULL; } managed_dtype.code = kDLFloat; - } else if (PyDataType_ISCOMPLEX(dtype)) { + } + else if (PyDataType_ISCOMPLEX(dtype)) { // We can't be sure that the dtype is // IEEE or padded. if (itemsize > 16) { @@ -160,7 +184,8 @@ array_dlpack(PyArrayObject *self, return NULL; } managed_dtype.code = kDLComplex; - } else { + } + else { PyErr_SetString(PyExc_TypeError, "DLPack only supports signed/unsigned integers, float " "and complex dtypes."); @@ -243,12 +268,12 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) { return NULL; } - DLManagedTensor *managed = + DLManagedTensor *managed = (DLManagedTensor *)PyCapsule_GetPointer(capsule, NPY_DLPACK_CAPSULE_NAME); if (managed == NULL) { - Py_XDECREF(capsule); + Py_DECREF(capsule); return NULL; } @@ -257,7 +282,7 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) { PyErr_SetString(PyExc_RuntimeError, "maxdims of DLPack tensor is higher than the supported " "maxdims."); - Py_XDECREF(capsule); + Py_DECREF(capsule); return NULL; } @@ -268,14 +293,14 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) { device_type != kDLCUDAManaged) { PyErr_SetString(PyExc_RuntimeError, "Unsupported device in DLTensor."); - Py_XDECREF(capsule); + Py_DECREF(capsule); return NULL; } if (managed->dl_tensor.dtype.lanes != 1) { PyErr_SetString(PyExc_RuntimeError, "Unsupported lanes in DLTensor dtype."); - Py_XDECREF(capsule); + Py_DECREF(capsule); return NULL; } @@ -321,13 +346,7 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) { if (typenum == -1) { PyErr_SetString(PyExc_RuntimeError, "Unsupported dtype in DLTensor."); - Py_XDECREF(capsule); - return NULL; - } - - PyArray_Descr *descr = PyArray_DescrFromType(typenum); - if (descr == NULL) { - Py_XDECREF(capsule); + Py_DECREF(capsule); return NULL; } @@ -346,11 +365,16 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) { char *data = (char *)managed->dl_tensor.data + managed->dl_tensor.byte_offset; + PyArray_Descr *descr = PyArray_DescrFromType(typenum); + if (descr == NULL) { + Py_DECREF(capsule); + return NULL; + } + PyObject *ret = PyArray_NewFromDescr(&PyArray_Type, descr, ndim, shape, managed->dl_tensor.strides != NULL ? strides : NULL, data, 0, NULL); if (ret == NULL) { - Py_XDECREF(capsule); - Py_XDECREF(descr); + Py_DECREF(capsule); return NULL; } @@ -358,24 +382,24 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) { NPY_DLPACK_INTERNAL_CAPSULE_NAME, array_dlpack_internal_capsule_deleter); if (new_capsule == NULL) { - Py_XDECREF(capsule); - Py_XDECREF(ret); + Py_DECREF(capsule); + Py_DECREF(ret); return NULL; } if (PyArray_SetBaseObject((PyArrayObject *)ret, new_capsule) < 0) { - Py_XDECREF(capsule); - Py_XDECREF(ret); + Py_DECREF(capsule); + Py_DECREF(ret); return NULL; } if (PyCapsule_SetName(capsule, NPY_DLPACK_USED_CAPSULE_NAME) < 0) { - Py_XDECREF(capsule); - Py_XDECREF(ret); + Py_DECREF(capsule); + Py_DECREF(ret); return NULL; } - - Py_XDECREF(capsule); + + Py_DECREF(capsule); return ret; } |