summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHameer Abbasi <einstein.edison@gmail.com>2021-05-30 11:57:46 +0200
committermattip <matti.picus@gmail.com>2021-11-02 11:34:40 +0200
commite83b8d8044763d36c42d3ab103a5437893fb09d8 (patch)
tree0cf869d5d99a70a77b366a2320dc7a41c0e129ed
parent9ebee26c6c5cd89623d531608eed25a770d01fff (diff)
downloadnumpy-e83b8d8044763d36c42d3ab103a5437893fb09d8.tar.gz
BUG, TST: Device bugfix and test __dl_device__.
-rw-r--r--numpy/__init__.pyi2
-rw-r--r--numpy/core/src/multiarray/methods.c39
-rw-r--r--numpy/core/tests/test_dlpack.py5
3 files changed, 32 insertions, 14 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index c808f0baf..63e723a35 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -2449,7 +2449,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
@overload
def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ...
@overload
- def __dlpack_device__(self) -> Tuple[L[1], L[0]]: ...
+ def __dlpack_device__(self) -> Tuple[int, L[0]]: ...
# Keep `dtype` at the bottom to avoid name conflicts with `np.dtype`
@property
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 42464014c..29372fe2f 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -2805,6 +2805,23 @@ static void array_dlpack_capsule_deleter(PyObject *self)
managed->deleter(managed);
}
+static DLDevice
+array_get_dl_device(PyArrayObject *self) {
+ DLDevice ret;
+ ret.device_type = kDLCPU;
+ ret.device_id = 0;
+ PyObject *base = PyArray_BASE(self);
+ if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) {
+ DLManagedTensor *managed = PyCapsule_GetPointer(
+ base, NPY_DLPACK_INTERNAL_CAPSULE_NAME);
+ if (managed == NULL) {
+ return ret;
+ }
+ return managed->dl_tensor.device;
+ }
+ return ret;
+}
+
static PyObject *
array_dlpack(PyArrayObject *self,
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
@@ -2886,8 +2903,11 @@ array_dlpack(PyArrayObject *self,
}
managed->dl_tensor.data = PyArray_DATA(self);
- managed->dl_tensor.device.device_type = kDLCPU;
- managed->dl_tensor.device.device_id = 0;
+ managed->dl_tensor.device = array_get_dl_device(self);
+ if (PyErr_Occurred()) {
+ PyMem_Free(managed);
+ return NULL;
+ }
managed->dl_tensor.dtype = managed_dtype;
@@ -2932,18 +2952,11 @@ array_dlpack(PyArrayObject *self,
static PyObject *
array_dlpack_device(PyArrayObject *self, PyObject *NPY_UNUSED(args))
{
- PyObject *base = PyArray_BASE(self);
- if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) {
- DLManagedTensor *managed = PyCapsule_GetPointer(base,
- NPY_DLPACK_INTERNAL_CAPSULE_NAME);
- if (managed == NULL) {
- return NULL;
- }
- return Py_BuildValue("ii", managed->dl_tensor.device.device_type,
- managed->dl_tensor.device.device_id);
+ DLDevice device = array_get_dl_device(self);
+ if (PyErr_Occurred()) {
+ return NULL;
}
- return Py_BuildValue("ii", kDLCPU, 0);
->>>>>>> ENH: Add the __dlpack__ and __dlpack_device__ methods to ndarray.
+ return Py_BuildValue("ii", device.device_type, device.device_id);
}
NPY_NO_EXPORT PyMethodDef array_methods[] = {
diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py
index 19f4e9281..926668c59 100644
--- a/numpy/core/tests/test_dlpack.py
+++ b/numpy/core/tests/test_dlpack.py
@@ -87,3 +87,8 @@ class TestDLPack:
x = np.zeros(shape, dtype=np.float64)
assert shape == np.from_dlpack(x).shape
+
+ def test_dlpack_device(self):
+ x = np.arange(5)
+ assert x.__dlpack_device__() == (1, 0)
+ assert np.from_dlpack(x).__dlpack_device__() == (1, 0)