summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/dlpack.c23
-rw-r--r--numpy/core/tests/test_dlpack.py5
2 files changed, 16 insertions, 12 deletions
diff --git a/numpy/core/src/multiarray/dlpack.c b/numpy/core/src/multiarray/dlpack.c
index 8491ed5b9..980b85395 100644
--- a/numpy/core/src/multiarray/dlpack.c
+++ b/numpy/core/src/multiarray/dlpack.c
@@ -15,8 +15,7 @@ static void
array_dlpack_deleter(DLManagedTensor *self)
{
PyArrayObject *array = (PyArrayObject *)self->manager_ctx;
- // This will also free the strides as it's one allocation.
- PyMem_Free(self->dl_tensor.shape);
+ // This will also free the shape and strides as it's one allocation.
PyMem_Free(self);
Py_XDECREF(array);
}
@@ -197,12 +196,17 @@ array_dlpack(PyArrayObject *self,
return NULL;
}
- DLManagedTensor *managed = PyMem_Malloc(sizeof(DLManagedTensor));
- if (managed == NULL) {
+ // ensure alignment
+ int offset = sizeof(DLManagedTensor) % sizeof(void *);
+ void *ptr = PyMem_Malloc(sizeof(DLManagedTensor) + offset +
+ (sizeof(int64_t) * ndim * 2));
+ if (ptr == NULL) {
PyErr_NoMemory();
return NULL;
}
+ DLManagedTensor *managed = ptr;
+
/*
* Note: the `dlpack.h` header suggests/standardizes that `data` must be
* 256-byte aligned. We ignore this intentionally, because `__dlpack__`
@@ -221,12 +225,8 @@ array_dlpack(PyArrayObject *self,
managed->dl_tensor.device = device;
managed->dl_tensor.dtype = managed_dtype;
- int64_t *managed_shape_strides = PyMem_Malloc(sizeof(int64_t) * ndim * 2);
- if (managed_shape_strides == NULL) {
- PyErr_NoMemory();
- PyMem_Free(managed);
- return NULL;
- }
+ int64_t *managed_shape_strides = (int64_t *)((char *)ptr +
+ sizeof(DLManagedTensor) + offset);
int64_t *managed_shape = managed_shape_strides;
int64_t *managed_strides = managed_shape_strides + ndim;
@@ -249,8 +249,7 @@ array_dlpack(PyArrayObject *self,
PyObject *capsule = PyCapsule_New(managed, NPY_DLPACK_CAPSULE_NAME,
dlpack_capsule_deleter);
if (capsule == NULL) {
- PyMem_Free(managed);
- PyMem_Free(managed_shape_strides);
+ PyMem_Free(ptr);
return NULL;
}
diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py
index 2ab55e903..203cf02c0 100644
--- a/numpy/core/tests/test_dlpack.py
+++ b/numpy/core/tests/test_dlpack.py
@@ -110,3 +110,8 @@ class TestDLPack:
x.flags.writeable = False
with pytest.raises(TypeError):
x.__dlpack__()
+
+ def test_ndim0(self):
+ x = np.array(1.0)
+ y = np._from_dlpack(x)
+ assert_array_equal(x, y)