diff options
| -rw-r--r-- | numpy/__init__.pyi | 2 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/methods.c | 39 | ||||
| -rw-r--r-- | numpy/core/tests/test_dlpack.py | 5 |
3 files changed, 32 insertions, 14 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index c808f0baf..63e723a35 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -2449,7 +2449,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]): @overload def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ... @overload - def __dlpack_device__(self) -> Tuple[L[1], L[0]]: ... + def __dlpack_device__(self) -> Tuple[int, L[0]]: ... # Keep `dtype` at the bottom to avoid name conflicts with `np.dtype` @property diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 42464014c..29372fe2f 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -2805,6 +2805,23 @@ static void array_dlpack_capsule_deleter(PyObject *self) managed->deleter(managed); } +static DLDevice +array_get_dl_device(PyArrayObject *self) { + DLDevice ret; + ret.device_type = kDLCPU; + ret.device_id = 0; + PyObject *base = PyArray_BASE(self); + if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) { + DLManagedTensor *managed = PyCapsule_GetPointer( + base, NPY_DLPACK_INTERNAL_CAPSULE_NAME); + if (managed == NULL) { + return ret; + } + return managed->dl_tensor.device; + } + return ret; +} + static PyObject * array_dlpack(PyArrayObject *self, PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames) @@ -2886,8 +2903,11 @@ array_dlpack(PyArrayObject *self, } managed->dl_tensor.data = PyArray_DATA(self); - managed->dl_tensor.device.device_type = kDLCPU; - managed->dl_tensor.device.device_id = 0; + managed->dl_tensor.device = array_get_dl_device(self); + if (PyErr_Occurred()) { + PyMem_Free(managed); + return NULL; + } managed->dl_tensor.dtype = managed_dtype; @@ -2932,18 +2952,11 @@ array_dlpack(PyArrayObject *self, static PyObject * array_dlpack_device(PyArrayObject *self, PyObject *NPY_UNUSED(args)) { - PyObject *base = PyArray_BASE(self); - if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) { - DLManagedTensor *managed = PyCapsule_GetPointer(base, - NPY_DLPACK_INTERNAL_CAPSULE_NAME); - if (managed == NULL) { - return NULL; - } - return Py_BuildValue("ii", managed->dl_tensor.device.device_type, - managed->dl_tensor.device.device_id); + DLDevice device = array_get_dl_device(self); + if (PyErr_Occurred()) { + return NULL; } - return Py_BuildValue("ii", kDLCPU, 0); ->>>>>>> ENH: Add the __dlpack__ and __dlpack_device__ methods to ndarray. + return Py_BuildValue("ii", device.device_type, device.device_id); } NPY_NO_EXPORT PyMethodDef array_methods[] = { diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py index 19f4e9281..926668c59 100644 --- a/numpy/core/tests/test_dlpack.py +++ b/numpy/core/tests/test_dlpack.py @@ -87,3 +87,8 @@ class TestDLPack: x = np.zeros(shape, dtype=np.float64) assert shape == np.from_dlpack(x).shape + + def test_dlpack_device(self): + x = np.arange(5) + assert x.__dlpack_device__() == (1, 0) + assert np.from_dlpack(x).__dlpack_device__() == (1, 0) |
