summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/dlpack.c6
-rw-r--r--numpy/core/tests/test_dlpack.py5
2 files changed, 10 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/dlpack.c b/numpy/core/src/multiarray/dlpack.c
index 291e60a22..8491ed5b9 100644
--- a/numpy/core/src/multiarray/dlpack.c
+++ b/numpy/core/src/multiarray/dlpack.c
@@ -88,6 +88,12 @@ array_get_dl_device(PyArrayObject *self) {
ret.device_type = kDLCPU;
ret.device_id = 0;
PyObject *base = PyArray_BASE(self);
+
+ // walk the bases (see gh-20340)
+ while (base != NULL && PyArray_Check(base)) {
+ base = PyArray_BASE((PyArrayObject *)base);
+ }
+
// The outer if is due to the fact that NumPy arrays are on the CPU
// by default (if not created from DLPack).
if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) {
diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py
index f848b2008..2ab55e903 100644
--- a/numpy/core/tests/test_dlpack.py
+++ b/numpy/core/tests/test_dlpack.py
@@ -91,7 +91,10 @@ class TestDLPack:
def test_dlpack_device(self):
x = np.arange(5)
assert x.__dlpack_device__() == (1, 0)
- assert np._from_dlpack(x).__dlpack_device__() == (1, 0)
+ y = np._from_dlpack(x)
+ assert y.__dlpack_device__() == (1, 0)
+ z = y[::2]
+ assert z.__dlpack_device__() == (1, 0)
def dlpack_deleter_exception(self):
x = np.arange(5)