diff options
| -rw-r--r-- | numpy/core/src/multiarray/dlpack.c | 6 | ||||
| -rw-r--r-- | numpy/core/tests/test_dlpack.py | 5 |
2 files changed, 10 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/dlpack.c b/numpy/core/src/multiarray/dlpack.c index 291e60a22..8491ed5b9 100644 --- a/numpy/core/src/multiarray/dlpack.c +++ b/numpy/core/src/multiarray/dlpack.c @@ -88,6 +88,12 @@ array_get_dl_device(PyArrayObject *self) { ret.device_type = kDLCPU; ret.device_id = 0; PyObject *base = PyArray_BASE(self); + + // walk the bases (see gh-20340) + while (base != NULL && PyArray_Check(base)) { + base = PyArray_BASE((PyArrayObject *)base); + } + // The outer if is due to the fact that NumPy arrays are on the CPU // by default (if not created from DLPack). if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) { diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py index f848b2008..2ab55e903 100644 --- a/numpy/core/tests/test_dlpack.py +++ b/numpy/core/tests/test_dlpack.py @@ -91,7 +91,10 @@ class TestDLPack: def test_dlpack_device(self): x = np.arange(5) assert x.__dlpack_device__() == (1, 0) - assert np._from_dlpack(x).__dlpack_device__() == (1, 0) + y = np._from_dlpack(x) + assert y.__dlpack_device__() == (1, 0) + z = y[::2] + assert z.__dlpack_device__() == (1, 0) def dlpack_deleter_exception(self): x = np.arange(5) |
