diff --git a/numpy/core/src/multiarray/dlpack.c b/numpy/core/src/multiarray/dlpack.c index 291e60a226a7..8491ed5b9bde 100644 --- a/numpy/core/src/multiarray/dlpack.c +++ b/numpy/core/src/multiarray/dlpack.c @@ -88,6 +88,12 @@ array_get_dl_device(PyArrayObject *self) { ret.device_type = kDLCPU; ret.device_id = 0; PyObject *base = PyArray_BASE(self); + + // walk the bases (see gh-20340) + while (base != NULL && PyArray_Check(base)) { + base = PyArray_BASE((PyArrayObject *)base); + } + // The outer if is due to the fact that NumPy arrays are on the CPU // by default (if not created from DLPack). if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) { diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py index f848b2008cf9..2ab55e903861 100644 --- a/numpy/core/tests/test_dlpack.py +++ b/numpy/core/tests/test_dlpack.py @@ -91,7 +91,10 @@ def test_higher_dims(self, ndim): def test_dlpack_device(self): x = np.arange(5) assert x.__dlpack_device__() == (1, 0) - assert np._from_dlpack(x).__dlpack_device__() == (1, 0) + y = np._from_dlpack(x) + assert y.__dlpack_device__() == (1, 0) + z = y[::2] + assert z.__dlpack_device__() == (1, 0) def dlpack_deleter_exception(self): x = np.arange(5)