summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormattip <matti.picus@gmail.com>2021-11-01 09:15:18 +0200
committermattip <matti.picus@gmail.com>2021-11-02 11:35:33 +0200
commit0c992dca1cc23f12af14a8d66101166ef6c92355 (patch)
tree8e1cb673cfb1d71ca328d82a2c36d1ec5b929a99
parentf96aaa2ac484cab4e9f40b36902544e1174e0513 (diff)
downloadnumpy-0c992dca1cc23f12af14a8d66101166ef6c92355.tar.gz
BUG: fixes from review
-rw-r--r--numpy/core/src/multiarray/dlpack.c104
1 files changed, 64 insertions, 40 deletions
diff --git a/numpy/core/src/multiarray/dlpack.c b/numpy/core/src/multiarray/dlpack.c
index 591eddfaf..f1591bb1f 100644
--- a/numpy/core/src/multiarray/dlpack.c
+++ b/numpy/core/src/multiarray/dlpack.c
@@ -22,7 +22,7 @@ array_dlpack_deleter(DLManagedTensor *self)
/* This is exactly as mandated by dlpack */
static void dlpack_capsule_deleter(PyObject *self) {
- if (PyCapsule_IsValid(self, "used_dltensor")) {
+ if (PyCapsule_IsValid(self, NPY_DLPACK_USED_CAPSULE_NAME)) {
return;
}
@@ -30,12 +30,16 @@ static void dlpack_capsule_deleter(PyObject *self) {
PyObject *type, *value, *traceback;
PyErr_Fetch(&type, &value, &traceback);
- DLManagedTensor *managed = (DLManagedTensor *)PyCapsule_GetPointer(self, "dltensor");
+ DLManagedTensor *managed =
+ (DLManagedTensor *)PyCapsule_GetPointer(self, NPY_DLPACK_CAPSULE_NAME);
if (managed == NULL) {
PyErr_WriteUnraisable(self);
goto done;
}
- /* the spec says the deleter can be NULL if there is no way for the caller to provide a reasonable destructor. */
+ /*
+ * the spec says the deleter can be NULL if there is no way for the caller
+ * to provide a reasonable destructor.
+ */
if (managed->deleter) {
managed->deleter(managed);
/* TODO: is the deleter allowed to set a python exception? */
@@ -46,6 +50,34 @@ done:
PyErr_Restore(type, value, traceback);
}
+/* used internally, almost identical to dlpack_capsule_deleter() */
+static void array_dlpack_internal_capsule_deleter(PyObject *self)
+{
+ /* an exception may be in-flight, we must save it in case we create another one */
+ PyObject *type, *value, *traceback;
+ PyErr_Fetch(&type, &value, &traceback);
+
+ DLManagedTensor *managed =
+ (DLManagedTensor *)PyCapsule_GetPointer(self, NPY_DLPACK_INTERNAL_CAPSULE_NAME);
+ if (managed == NULL) {
+ PyErr_WriteUnraisable(self);
+ goto done;
+ }
+ /*
+ * the spec says the deleter can be NULL if there is no way for the caller
+ * to provide a reasonable destructor.
+ */
+ if (managed->deleter) {
+ managed->deleter(managed);
+ /* TODO: is the deleter allowed to set a python exception? */
+ assert(!PyErr_Occurred());
+ }
+
+done:
+ PyErr_Restore(type, value, traceback);
+}
+
+
// This function cannot return NULL, but it can fail,
// So call PyErr_Occurred to check if it failed after
// calling it.
@@ -82,17 +114,6 @@ array_get_dl_data(PyArrayObject *self) {
return PyArray_DATA(self);
}
-/* used internally */
-static void array_dlpack_internal_capsule_deleter(PyObject *self)
-{
- DLManagedTensor *managed =
- (DLManagedTensor *)PyCapsule_GetPointer(self, NPY_DLPACK_INTERNAL_CAPSULE_NAME);
- if (managed == NULL) {
- return;
- }
- managed->deleter(managed);
-}
-
PyObject *
array_dlpack(PyArrayObject *self,
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
@@ -140,9 +161,11 @@ array_dlpack(PyArrayObject *self,
if (PyDataType_ISSIGNED(dtype)) {
managed_dtype.code = kDLInt;
- } else if (PyDataType_ISUNSIGNED(dtype)) {
+ }
+ else if (PyDataType_ISUNSIGNED(dtype)) {
managed_dtype.code = kDLUInt;
- } else if (PyDataType_ISFLOAT(dtype)) {
+ }
+ else if (PyDataType_ISFLOAT(dtype)) {
// We can't be sure that the dtype is
// IEEE or padded.
if (itemsize > 8) {
@@ -151,7 +174,8 @@ array_dlpack(PyArrayObject *self,
return NULL;
}
managed_dtype.code = kDLFloat;
- } else if (PyDataType_ISCOMPLEX(dtype)) {
+ }
+ else if (PyDataType_ISCOMPLEX(dtype)) {
// We can't be sure that the dtype is
// IEEE or padded.
if (itemsize > 16) {
@@ -160,7 +184,8 @@ array_dlpack(PyArrayObject *self,
return NULL;
}
managed_dtype.code = kDLComplex;
- } else {
+ }
+ else {
PyErr_SetString(PyExc_TypeError,
"DLPack only supports signed/unsigned integers, float "
"and complex dtypes.");
@@ -243,12 +268,12 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) {
return NULL;
}
- DLManagedTensor *managed =
+ DLManagedTensor *managed =
(DLManagedTensor *)PyCapsule_GetPointer(capsule,
NPY_DLPACK_CAPSULE_NAME);
if (managed == NULL) {
- Py_XDECREF(capsule);
+ Py_DECREF(capsule);
return NULL;
}
@@ -257,7 +282,7 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) {
PyErr_SetString(PyExc_RuntimeError,
"maxdims of DLPack tensor is higher than the supported "
"maxdims.");
- Py_XDECREF(capsule);
+ Py_DECREF(capsule);
return NULL;
}
@@ -268,14 +293,14 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) {
device_type != kDLCUDAManaged) {
PyErr_SetString(PyExc_RuntimeError,
"Unsupported device in DLTensor.");
- Py_XDECREF(capsule);
+ Py_DECREF(capsule);
return NULL;
}
if (managed->dl_tensor.dtype.lanes != 1) {
PyErr_SetString(PyExc_RuntimeError,
"Unsupported lanes in DLTensor dtype.");
- Py_XDECREF(capsule);
+ Py_DECREF(capsule);
return NULL;
}
@@ -321,13 +346,7 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) {
if (typenum == -1) {
PyErr_SetString(PyExc_RuntimeError,
"Unsupported dtype in DLTensor.");
- Py_XDECREF(capsule);
- return NULL;
- }
-
- PyArray_Descr *descr = PyArray_DescrFromType(typenum);
- if (descr == NULL) {
- Py_XDECREF(capsule);
+ Py_DECREF(capsule);
return NULL;
}
@@ -346,11 +365,16 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) {
char *data = (char *)managed->dl_tensor.data +
managed->dl_tensor.byte_offset;
+ PyArray_Descr *descr = PyArray_DescrFromType(typenum);
+ if (descr == NULL) {
+ Py_DECREF(capsule);
+ return NULL;
+ }
+
PyObject *ret = PyArray_NewFromDescr(&PyArray_Type, descr, ndim, shape,
managed->dl_tensor.strides != NULL ? strides : NULL, data, 0, NULL);
if (ret == NULL) {
- Py_XDECREF(capsule);
- Py_XDECREF(descr);
+ Py_DECREF(capsule);
return NULL;
}
@@ -358,24 +382,24 @@ from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) {
NPY_DLPACK_INTERNAL_CAPSULE_NAME,
array_dlpack_internal_capsule_deleter);
if (new_capsule == NULL) {
- Py_XDECREF(capsule);
- Py_XDECREF(ret);
+ Py_DECREF(capsule);
+ Py_DECREF(ret);
return NULL;
}
if (PyArray_SetBaseObject((PyArrayObject *)ret, new_capsule) < 0) {
- Py_XDECREF(capsule);
- Py_XDECREF(ret);
+ Py_DECREF(capsule);
+ Py_DECREF(ret);
return NULL;
}
if (PyCapsule_SetName(capsule, NPY_DLPACK_USED_CAPSULE_NAME) < 0) {
- Py_XDECREF(capsule);
- Py_XDECREF(ret);
+ Py_DECREF(capsule);
+ Py_DECREF(ret);
return NULL;
}
-
- Py_XDECREF(capsule);
+
+ Py_DECREF(capsule);
return ret;
}