diff options
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 33 | ||||
-rw-r--r-- | numpy/core/tests/test_dlpack.py | 9 |
2 files changed, 34 insertions, 8 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 29372fe2f..9a16adecd 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -2792,25 +2792,42 @@ array_dlpack_deleter(DLManagedTensor *self) } /* This is exactly as mandated by dlpack */ -static void array_dlpack_capsule_deleter(PyObject *self) -{ - if (PyCapsule_IsValid(self, NPY_DLPACK_USED_CAPSULE_NAME)) { +static void dlpack_capsule_deleter(PyObject *self){ + if (PyCapsule_IsValid(self, "used_dltensor")) { return; } - DLManagedTensor *managed = - (DLManagedTensor *)PyCapsule_GetPointer(self, NPY_DLPACK_CAPSULE_NAME); + + /* an exception may be in-flight, we must save it in case we create another one */ + PyObject *type, *value, *traceback; + PyErr_Fetch(&type, &value, &traceback); + + DLManagedTensor *managed = (DLManagedTensor *)PyCapsule_GetPointer(self, "dltensor"); if (managed == NULL) { - return; + PyErr_WriteUnraisable(self); + goto done; } - managed->deleter(managed); + /* the spec says the deleter can be NULL if there is no way for the caller to provide a reasonable destructor. */ + if (managed->deleter) { + managed->deleter(managed); + /* TODO: is the deleter allowed to set a python exception? */ + assert(!PyErr_Occurred()); + } + +done: + PyErr_Restore(type, value, traceback); } +// This function cannot return NULL, but it can fail, +// So call PyErr_Occurred to check if it failed after +// calling it. static DLDevice array_get_dl_device(PyArrayObject *self) { DLDevice ret; ret.device_type = kDLCPU; ret.device_id = 0; PyObject *base = PyArray_BASE(self); + // The outer if is due to the fact that NumPy arrays are on the CPU + // by default. if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) { DLManagedTensor *managed = PyCapsule_GetPointer( base, NPY_DLPACK_INTERNAL_CAPSULE_NAME); @@ -2937,7 +2954,7 @@ array_dlpack(PyArrayObject *self, managed->deleter = array_dlpack_deleter; PyObject *capsule = PyCapsule_New(managed, NPY_DLPACK_CAPSULE_NAME, - array_dlpack_capsule_deleter); + dlpack_capsule_deleter); if (capsule == NULL) { PyMem_Free(managed); PyMem_Free(managed_shape_strides); diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py index 926668c59..2561991e8 100644 --- a/numpy/core/tests/test_dlpack.py +++ b/numpy/core/tests/test_dlpack.py @@ -92,3 +92,12 @@ class TestDLPack: x = np.arange(5) assert x.__dlpack_device__() == (1, 0) assert np.from_dlpack(x).__dlpack_device__() == (1, 0) + + def dlpack_deleter_exception(self): + x = np.arange(5) + _ = x.__dlpack__() + raise RuntimeError + + def test_dlpack_destructor_exception(self): + with pytest.raises(RuntimeError): + self.dlpack_deleter_exception() |