diff options
| author | Tirth Patel <tirthasheshpatel@gmail.com> | 2022-02-25 12:35:04 +0530 |
|---|---|---|
| committer | Tirth Patel <tirthasheshpatel@gmail.com> | 2022-02-25 12:35:04 +0530 |
| commit | 25b3def8aec6eccc6e43a18746d07a9cafbe7449 (patch) | |
| tree | 70438972ba41563924967d922392e733a6ab9eff /numpy | |
| parent | 41d37b714caa1eef72f984d529f1d40ed48ce535 (diff) | |
| download | numpy-25b3def8aec6eccc6e43a18746d07a9cafbe7449.tar.gz | |
BUG, ENH: np._from_dlpack: export correct device information
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/src/multiarray/dlpack.c | 6 | ||||
| -rw-r--r-- | numpy/core/tests/test_dlpack.py | 5 |
2 files changed, 10 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/dlpack.c b/numpy/core/src/multiarray/dlpack.c index 291e60a22..8491ed5b9 100644 --- a/numpy/core/src/multiarray/dlpack.c +++ b/numpy/core/src/multiarray/dlpack.c @@ -88,6 +88,12 @@ array_get_dl_device(PyArrayObject *self) { ret.device_type = kDLCPU; ret.device_id = 0; PyObject *base = PyArray_BASE(self); + + // walk the bases (see gh-20340) + while (base != NULL && PyArray_Check(base)) { + base = PyArray_BASE((PyArrayObject *)base); + } + // The outer if is due to the fact that NumPy arrays are on the CPU // by default (if not created from DLPack). if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) { diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py index f848b2008..2ab55e903 100644 --- a/numpy/core/tests/test_dlpack.py +++ b/numpy/core/tests/test_dlpack.py @@ -91,7 +91,10 @@ class TestDLPack: def test_dlpack_device(self): x = np.arange(5) assert x.__dlpack_device__() == (1, 0) - assert np._from_dlpack(x).__dlpack_device__() == (1, 0) + y = np._from_dlpack(x) + assert y.__dlpack_device__() == (1, 0) + z = y[::2] + assert z.__dlpack_device__() == (1, 0) def dlpack_deleter_exception(self): x = np.arange(5) |
