summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHameer Abbasi <einstein.edison@gmail.com>2021-05-31 09:46:20 +0200
committermattip <matti.picus@gmail.com>2021-11-02 11:34:40 +0200
commite167da747b53baed13a5148750c9c82746ffcb30 (patch)
tree72c15fbf7652257da23c72327c274108ee0dbd30
parente83b8d8044763d36c42d3ab103a5437893fb09d8 (diff)
downloadnumpy-e167da747b53baed13a5148750c9c82746ffcb30.tar.gz
MAINT: Robustify dlpack_capsule_deleter and add comments.
-rw-r--r--numpy/core/src/multiarray/methods.c33
-rw-r--r--numpy/core/tests/test_dlpack.py9
2 files changed, 34 insertions, 8 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 29372fe2f..9a16adecd 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -2792,25 +2792,42 @@ array_dlpack_deleter(DLManagedTensor *self)
}
/* This is exactly as mandated by dlpack */
-static void array_dlpack_capsule_deleter(PyObject *self)
-{
- if (PyCapsule_IsValid(self, NPY_DLPACK_USED_CAPSULE_NAME)) {
+static void dlpack_capsule_deleter(PyObject *self){
+ if (PyCapsule_IsValid(self, "used_dltensor")) {
return;
}
- DLManagedTensor *managed =
- (DLManagedTensor *)PyCapsule_GetPointer(self, NPY_DLPACK_CAPSULE_NAME);
+
+ /* an exception may be in-flight, we must save it in case we create another one */
+ PyObject *type, *value, *traceback;
+ PyErr_Fetch(&type, &value, &traceback);
+
+ DLManagedTensor *managed = (DLManagedTensor *)PyCapsule_GetPointer(self, "dltensor");
if (managed == NULL) {
- return;
+ PyErr_WriteUnraisable(self);
+ goto done;
}
- managed->deleter(managed);
+ /* the spec says the deleter can be NULL if there is no way for the caller to provide a reasonable destructor. */
+ if (managed->deleter) {
+ managed->deleter(managed);
+ /* TODO: is the deleter allowed to set a python exception? */
+ assert(!PyErr_Occurred());
+ }
+
+done:
+ PyErr_Restore(type, value, traceback);
}
+// This function cannot return NULL, but it can fail,
+// So call PyErr_Occurred to check if it failed after
+// calling it.
static DLDevice
array_get_dl_device(PyArrayObject *self) {
DLDevice ret;
ret.device_type = kDLCPU;
ret.device_id = 0;
PyObject *base = PyArray_BASE(self);
+ // The outer if is due to the fact that NumPy arrays are on the CPU
+ // by default.
if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) {
DLManagedTensor *managed = PyCapsule_GetPointer(
base, NPY_DLPACK_INTERNAL_CAPSULE_NAME);
@@ -2937,7 +2954,7 @@ array_dlpack(PyArrayObject *self,
managed->deleter = array_dlpack_deleter;
PyObject *capsule = PyCapsule_New(managed, NPY_DLPACK_CAPSULE_NAME,
- array_dlpack_capsule_deleter);
+ dlpack_capsule_deleter);
if (capsule == NULL) {
PyMem_Free(managed);
PyMem_Free(managed_shape_strides);
diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py
index 926668c59..2561991e8 100644
--- a/numpy/core/tests/test_dlpack.py
+++ b/numpy/core/tests/test_dlpack.py
@@ -92,3 +92,12 @@ class TestDLPack:
x = np.arange(5)
assert x.__dlpack_device__() == (1, 0)
assert np.from_dlpack(x).__dlpack_device__() == (1, 0)
+
+ def dlpack_deleter_exception(self):
+ x = np.arange(5)
+ _ = x.__dlpack__()
+ raise RuntimeError
+
+ def test_dlpack_destructor_exception(self):
+ with pytest.raises(RuntimeError):
+ self.dlpack_deleter_exception()