diff options
| -rw-r--r-- | numpy/core/src/multiarray/dlpack.c | 23 | ||||
| -rw-r--r-- | numpy/core/tests/test_dlpack.py | 5 |
2 files changed, 16 insertions, 12 deletions
diff --git a/numpy/core/src/multiarray/dlpack.c b/numpy/core/src/multiarray/dlpack.c index 8491ed5b9..980b85395 100644 --- a/numpy/core/src/multiarray/dlpack.c +++ b/numpy/core/src/multiarray/dlpack.c @@ -15,8 +15,7 @@ static void array_dlpack_deleter(DLManagedTensor *self) { PyArrayObject *array = (PyArrayObject *)self->manager_ctx; - // This will also free the strides as it's one allocation. - PyMem_Free(self->dl_tensor.shape); + // This will also free the shape and strides as it's one allocation. PyMem_Free(self); Py_XDECREF(array); } @@ -197,12 +196,17 @@ array_dlpack(PyArrayObject *self, return NULL; } - DLManagedTensor *managed = PyMem_Malloc(sizeof(DLManagedTensor)); - if (managed == NULL) { + // ensure alignment + int offset = sizeof(DLManagedTensor) % sizeof(void *); + void *ptr = PyMem_Malloc(sizeof(DLManagedTensor) + offset + + (sizeof(int64_t) * ndim * 2)); + if (ptr == NULL) { PyErr_NoMemory(); return NULL; } + DLManagedTensor *managed = ptr; + /* * Note: the `dlpack.h` header suggests/standardizes that `data` must be * 256-byte aligned. We ignore this intentionally, because `__dlpack__` @@ -221,12 +225,8 @@ array_dlpack(PyArrayObject *self, managed->dl_tensor.device = device; managed->dl_tensor.dtype = managed_dtype; - int64_t *managed_shape_strides = PyMem_Malloc(sizeof(int64_t) * ndim * 2); - if (managed_shape_strides == NULL) { - PyErr_NoMemory(); - PyMem_Free(managed); - return NULL; - } + int64_t *managed_shape_strides = (int64_t *)((char *)ptr + + sizeof(DLManagedTensor) + offset); int64_t *managed_shape = managed_shape_strides; int64_t *managed_strides = managed_shape_strides + ndim; @@ -249,8 +249,7 @@ array_dlpack(PyArrayObject *self, PyObject *capsule = PyCapsule_New(managed, NPY_DLPACK_CAPSULE_NAME, dlpack_capsule_deleter); if (capsule == NULL) { - PyMem_Free(managed); - PyMem_Free(managed_shape_strides); + PyMem_Free(ptr); return NULL; } diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py index 2ab55e903..203cf02c0 100644 --- a/numpy/core/tests/test_dlpack.py +++ b/numpy/core/tests/test_dlpack.py @@ -110,3 +110,8 @@ class TestDLPack: x.flags.writeable = False with pytest.raises(TypeError): x.__dlpack__() + + def test_ndim0(self): + x = np.array(1.0) + y = np._from_dlpack(x) + assert_array_equal(x, y) |
