summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorTirth Patel <tirthasheshpatel@gmail.com>2022-02-25 12:35:04 +0530
committerTirth Patel <tirthasheshpatel@gmail.com>2022-02-25 12:35:04 +0530
commit25b3def8aec6eccc6e43a18746d07a9cafbe7449 (patch)
tree70438972ba41563924967d922392e733a6ab9eff /numpy
parent41d37b714caa1eef72f984d529f1d40ed48ce535 (diff)
downloadnumpy-25b3def8aec6eccc6e43a18746d07a9cafbe7449.tar.gz
BUG, ENH: np._from_dlpack: export correct device information
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/dlpack.c6
-rw-r--r--numpy/core/tests/test_dlpack.py5
2 files changed, 10 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/dlpack.c b/numpy/core/src/multiarray/dlpack.c
index 291e60a22..8491ed5b9 100644
--- a/numpy/core/src/multiarray/dlpack.c
+++ b/numpy/core/src/multiarray/dlpack.c
@@ -88,6 +88,12 @@ array_get_dl_device(PyArrayObject *self) {
ret.device_type = kDLCPU;
ret.device_id = 0;
PyObject *base = PyArray_BASE(self);
+
+ // walk the bases (see gh-20340)
+ while (base != NULL && PyArray_Check(base)) {
+ base = PyArray_BASE((PyArrayObject *)base);
+ }
+
// The outer if is due to the fact that NumPy arrays are on the CPU
// by default (if not created from DLPack).
if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) {
diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py
index f848b2008..2ab55e903 100644
--- a/numpy/core/tests/test_dlpack.py
+++ b/numpy/core/tests/test_dlpack.py
@@ -91,7 +91,10 @@ class TestDLPack:
def test_dlpack_device(self):
x = np.arange(5)
assert x.__dlpack_device__() == (1, 0)
- assert np._from_dlpack(x).__dlpack_device__() == (1, 0)
+ y = np._from_dlpack(x)
+ assert y.__dlpack_device__() == (1, 0)
+ z = y[::2]
+ assert z.__dlpack_device__() == (1, 0)
def dlpack_deleter_exception(self):
x = np.arange(5)