summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/__init__.pyi17
-rw-r--r--numpy/array_api/__init__.py4
-rw-r--r--numpy/array_api/_creation_functions.py2
-rw-r--r--numpy/core/_add_newdocs.py22
-rw-r--r--numpy/core/code_generators/genapi.py1
-rw-r--r--numpy/core/multiarray.py22
-rw-r--r--numpy/core/numeric.py6
-rw-r--r--numpy/core/setup.py3
-rw-r--r--numpy/core/src/common/dlpack/dlpack.h201
-rw-r--r--numpy/core/src/common/npy_dlpack.h28
-rw-r--r--numpy/core/src/multiarray/dlpack.c408
-rw-r--r--numpy/core/src/multiarray/methods.c9
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c5
-rw-r--r--numpy/core/tests/test_dlpack.py109
14 files changed, 820 insertions, 17 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 637f8578a..b2e9eec77 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -1413,6 +1413,7 @@ _SupportsBuffer = Union[
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
+_T_contra = TypeVar("_T_contra", contravariant=True)
_2Tuple = Tuple[_T, _T]
_CastingKind = L["no", "equiv", "safe", "same_kind", "unsafe"]
@@ -1432,6 +1433,10 @@ _ArrayTD64_co = NDArray[Union[bool_, integer[Any], timedelta64]]
# Introduce an alias for `dtype` to avoid naming conflicts.
_dtype = dtype
+# `builtins.PyCapsule` unfortunately lacks annotations as of the moment;
+# use `Any` as a stopgap measure
+_PyCapsule = Any
+
class _SupportsItem(Protocol[_T_co]):
def item(self, args: Any, /) -> _T_co: ...
@@ -2439,6 +2444,12 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
def __ior__(self: NDArray[signedinteger[_NBit1]], other: _ArrayLikeInt_co) -> NDArray[signedinteger[_NBit1]]: ...
@overload
def __ior__(self: NDArray[object_], other: Any) -> NDArray[object_]: ...
+ @overload
+ def __ior__(self: NDArray[_ScalarType], other: _RecursiveSequence) -> NDArray[_ScalarType]: ...
+ @overload
+ def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ...
+ @overload
+ def __dlpack_device__(self) -> Tuple[int, L[0]]: ...
# Keep `dtype` at the bottom to avoid name conflicts with `np.dtype`
@property
@@ -4320,3 +4331,9 @@ class chararray(ndarray[_ShapeType, _CharDType]):
# NOTE: Deprecated
# class MachAr: ...
+
+class _SupportsDLPack(Protocol[_T_contra]):
+ def __dlpack__(self, *, stream: None | _T_contra = ...) -> _PyCapsule: ...
+
+def _from_dlpack(__obj: _SupportsDLPack[None]) -> NDArray[Any]: ...
+
diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py
index d8b29057e..89f5e9cba 100644
--- a/numpy/array_api/__init__.py
+++ b/numpy/array_api/__init__.py
@@ -136,7 +136,7 @@ from ._creation_functions import (
empty,
empty_like,
eye,
- from_dlpack,
+ _from_dlpack,
full,
full_like,
linspace,
@@ -155,7 +155,7 @@ __all__ += [
"empty",
"empty_like",
"eye",
- "from_dlpack",
+ "_from_dlpack",
"full",
"full_like",
"linspace",
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py
index e36807468..c3644ac2c 100644
--- a/numpy/array_api/_creation_functions.py
+++ b/numpy/array_api/_creation_functions.py
@@ -151,7 +151,7 @@ def eye(
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
-def from_dlpack(x: object, /) -> Array:
+def _from_dlpack(x: object, /) -> Array:
# Note: dlpack support is not yet implemented on Array
raise NotImplementedError("DLPack support is not yet implemented")
diff --git a/numpy/core/_add_newdocs.py b/numpy/core/_add_newdocs.py
index c8a24db0c..cae5bc281 100644
--- a/numpy/core/_add_newdocs.py
+++ b/numpy/core/_add_newdocs.py
@@ -1573,6 +1573,19 @@ add_newdoc('numpy.core.multiarray', 'frombuffer',
array_function_like_doc,
))
+add_newdoc('numpy.core.multiarray', '_from_dlpack',
+ """
+ _from_dlpack(x, /)
+
+ Create a NumPy array from an object implementing the ``__dlpack__``
+ protocol.
+
+ See Also
+ --------
+ `Array API documentation
+ <https://data-apis.org/array-api/latest/design_topics/data_interchange.html#syntax-for-data-interchange-with-dlpack>`_
+ """)
+
add_newdoc('numpy.core', 'fastCopyAndTranspose',
"""_fastCopyAndTranspose(a)""")
@@ -2263,6 +2276,15 @@ add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_priority__',
add_newdoc('numpy.core.multiarray', 'ndarray', ('__array_struct__',
"""Array protocol: C-struct side."""))
+add_newdoc('numpy.core.multiarray', 'ndarray', ('__dlpack__',
+ """a.__dlpack__(*, stream=None)
+
+ DLPack Protocol: Part of the Array API."""))
+
+add_newdoc('numpy.core.multiarray', 'ndarray', ('__dlpack_device__',
+ """a.__dlpack_device__()
+
+ DLPack Protocol: Part of the Array API."""))
add_newdoc('numpy.core.multiarray', 'ndarray', ('base',
"""
diff --git a/numpy/core/code_generators/genapi.py b/numpy/core/code_generators/genapi.py
index c2458c2b5..b401ee6a5 100644
--- a/numpy/core/code_generators/genapi.py
+++ b/numpy/core/code_generators/genapi.py
@@ -41,6 +41,7 @@ API_FILES = [join('multiarray', 'alloc.c'),
join('multiarray', 'datetime_busdaycal.c'),
join('multiarray', 'datetime_strings.c'),
join('multiarray', 'descriptor.c'),
+ join('multiarray', 'dlpack.c'),
join('multiarray', 'dtypemeta.c'),
join('multiarray', 'einsum.c.src'),
join('multiarray', 'flagsobject.c'),
diff --git a/numpy/core/multiarray.py b/numpy/core/multiarray.py
index 351cd3a1b..f96274263 100644
--- a/numpy/core/multiarray.py
+++ b/numpy/core/multiarray.py
@@ -14,8 +14,9 @@ from ._multiarray_umath import * # noqa: F403
# do not change them. issue gh-15518
# _get_ndarray_c_version is semi-public, on purpose not added to __all__
from ._multiarray_umath import (
- _fastCopyAndTranspose, _flagdict, _insert, _reconstruct, _vec_string,
- _ARRAY_API, _monotonicity, _get_ndarray_c_version, _set_madvise_hugepage,
+ _fastCopyAndTranspose, _flagdict, _from_dlpack, _insert, _reconstruct,
+ _vec_string, _ARRAY_API, _monotonicity, _get_ndarray_c_version,
+ _set_madvise_hugepage,
)
__all__ = [
@@ -23,18 +24,18 @@ __all__ = [
'ITEM_HASOBJECT', 'ITEM_IS_POINTER', 'LIST_PICKLE', 'MAXDIMS',
'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT', 'NEEDS_INIT', 'NEEDS_PYAPI',
'RAISE', 'USE_GETITEM', 'USE_SETITEM', 'WRAP', '_fastCopyAndTranspose',
- '_flagdict', '_insert', '_reconstruct', '_vec_string', '_monotonicity',
- 'add_docstring', 'arange', 'array', 'asarray', 'asanyarray',
- 'ascontiguousarray', 'asfortranarray', 'bincount', 'broadcast',
- 'busday_count', 'busday_offset', 'busdaycalendar', 'can_cast',
+ '_flagdict', '_from_dlpack', '_insert', '_reconstruct', '_vec_string',
+ '_monotonicity', 'add_docstring', 'arange', 'array', 'asarray',
+ 'asanyarray', 'ascontiguousarray', 'asfortranarray', 'bincount',
+ 'broadcast', 'busday_count', 'busday_offset', 'busdaycalendar', 'can_cast',
'compare_chararrays', 'concatenate', 'copyto', 'correlate', 'correlate2',
'count_nonzero', 'c_einsum', 'datetime_as_string', 'datetime_data',
'dot', 'dragon4_positional', 'dragon4_scientific', 'dtype',
'empty', 'empty_like', 'error', 'flagsobj', 'flatiter', 'format_longfloat',
- 'frombuffer', 'fromfile', 'fromiter', 'fromstring', 'get_handler_name',
- 'inner', 'interp', 'interp_complex', 'is_busday', 'lexsort',
- 'matmul', 'may_share_memory', 'min_scalar_type', 'ndarray', 'nditer',
- 'nested_iters', 'normalize_axis_index', 'packbits',
+ 'frombuffer', 'fromfile', 'fromiter', 'fromstring',
+ 'get_handler_name', 'inner', 'interp', 'interp_complex', 'is_busday',
+ 'lexsort', 'matmul', 'may_share_memory', 'min_scalar_type', 'ndarray',
+ 'nditer', 'nested_iters', 'normalize_axis_index', 'packbits',
'promote_types', 'putmask', 'ravel_multi_index', 'result_type', 'scalar',
'set_datetimeparse_function', 'set_legacy_print_mode', 'set_numeric_ops',
'set_string_function', 'set_typeDict', 'shares_memory',
@@ -46,6 +47,7 @@ _reconstruct.__module__ = 'numpy.core.multiarray'
scalar.__module__ = 'numpy.core.multiarray'
+_from_dlpack.__module__ = 'numpy'
arange.__module__ = 'numpy'
array.__module__ = 'numpy'
asarray.__module__ = 'numpy'
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 1654e8364..344d40d93 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -13,8 +13,8 @@ from .multiarray import (
WRAP, arange, array, asarray, asanyarray, ascontiguousarray,
asfortranarray, broadcast, can_cast, compare_chararrays,
concatenate, copyto, dot, dtype, empty,
- empty_like, flatiter, frombuffer, fromfile, fromiter, fromstring,
- inner, lexsort, matmul, may_share_memory,
+ empty_like, flatiter, frombuffer, _from_dlpack, fromfile, fromiter,
+ fromstring, inner, lexsort, matmul, may_share_memory,
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
putmask, result_type, set_numeric_ops, shares_memory, vdot, where,
zeros, normalize_axis_index)
@@ -41,7 +41,7 @@ __all__ = [
'newaxis', 'ndarray', 'flatiter', 'nditer', 'nested_iters', 'ufunc',
'arange', 'array', 'asarray', 'asanyarray', 'ascontiguousarray',
'asfortranarray', 'zeros', 'count_nonzero', 'empty', 'broadcast', 'dtype',
- 'fromstring', 'fromfile', 'frombuffer', 'where',
+ 'fromstring', 'fromfile', 'frombuffer', '_from_dlpack', 'where',
'argwhere', 'copyto', 'concatenate', 'fastCopyAndTranspose', 'lexsort',
'set_numeric_ops', 'can_cast', 'promote_types', 'min_scalar_type',
'result_type', 'isfortran', 'empty_like', 'zeros_like', 'ones_like',
diff --git a/numpy/core/setup.py b/numpy/core/setup.py
index 3e1ed4c9b..2c99060ec 100644
--- a/numpy/core/setup.py
+++ b/numpy/core/setup.py
@@ -740,6 +740,7 @@ def configuration(parent_package='',top_path=None):
#######################################################################
common_deps = [
+ join('src', 'common', 'dlpack', 'dlpack.h'),
join('src', 'common', 'array_assign.h'),
join('src', 'common', 'binop_override.h'),
join('src', 'common', 'cblasfuncs.h'),
@@ -749,6 +750,7 @@ def configuration(parent_package='',top_path=None):
join('src', 'common', 'npy_cblas.h'),
join('src', 'common', 'npy_config.h'),
join('src', 'common', 'npy_ctypes.h'),
+ join('src', 'common', 'npy_dlpack.h'),
join('src', 'common', 'npy_extint128.h'),
join('src', 'common', 'npy_import.h'),
join('src', 'common', 'npy_hashtable.h'),
@@ -881,6 +883,7 @@ def configuration(parent_package='',top_path=None):
join('src', 'multiarray', 'datetime_busday.c'),
join('src', 'multiarray', 'datetime_busdaycal.c'),
join('src', 'multiarray', 'descriptor.c'),
+ join('src', 'multiarray', 'dlpack.c'),
join('src', 'multiarray', 'dtypemeta.c'),
join('src', 'multiarray', 'dragon4.c'),
join('src', 'multiarray', 'dtype_transfer.c'),
diff --git a/numpy/core/src/common/dlpack/dlpack.h b/numpy/core/src/common/dlpack/dlpack.h
new file mode 100644
index 000000000..29209aee1
--- /dev/null
+++ b/numpy/core/src/common/dlpack/dlpack.h
@@ -0,0 +1,201 @@
+// Taken from:
+// https://github.com/dmlc/dlpack/blob/9b6176fdecb55e9bf39b16f08b96913ed3f275b4/include/dlpack/dlpack.h
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file dlpack.h
+ * \brief The common header of DLPack.
+ */
+#ifndef DLPACK_DLPACK_H_
+#define DLPACK_DLPACK_H_
+
+#ifdef __cplusplus
+#define DLPACK_EXTERN_C extern "C"
+#else
+#define DLPACK_EXTERN_C
+#endif
+
+/*! \brief The current version of dlpack */
+#define DLPACK_VERSION 050
+
+/*! \brief DLPACK_DLL prefix for windows */
+#ifdef _WIN32
+#ifdef DLPACK_EXPORTS
+#define DLPACK_DLL __declspec(dllexport)
+#else
+#define DLPACK_DLL __declspec(dllimport)
+#endif
+#else
+#define DLPACK_DLL
+#endif
+
+#include <stdint.h>
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*!
+ * \brief The device type in DLDevice.
+ */
+typedef enum {
+ /*! \brief CPU device */
+ kDLCPU = 1,
+ /*! \brief CUDA GPU device */
+ kDLCUDA = 2,
+ /*!
+ * \brief Pinned CUDA CPU memory by cudaMallocHost
+ */
+ kDLCUDAHost = 3,
+ /*! \brief OpenCL devices. */
+ kDLOpenCL = 4,
+ /*! \brief Vulkan buffer for next generation graphics. */
+ kDLVulkan = 7,
+ /*! \brief Metal for Apple GPU. */
+ kDLMetal = 8,
+ /*! \brief Verilog simulator buffer */
+ kDLVPI = 9,
+ /*! \brief ROCm GPUs for AMD GPUs */
+ kDLROCM = 10,
+ /*!
+ * \brief Pinned ROCm CPU memory allocated by hipMallocHost
+ */
+ kDLROCMHost = 11,
+ /*!
+ * \brief Reserved extension device type,
+ * used for quickly test extension device
+ * The semantics can differ depending on the implementation.
+ */
+ kDLExtDev = 12,
+ /*!
+ * \brief CUDA managed/unified memory allocated by cudaMallocManaged
+ */
+ kDLCUDAManaged = 13,
+} DLDeviceType;
+
+/*!
+ * \brief A Device for Tensor and operator.
+ */
+typedef struct {
+ /*! \brief The device type used in the device. */
+ DLDeviceType device_type;
+ /*!
+ * \brief The device index.
+ * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.
+ */
+ int device_id;
+} DLDevice;
+
+/*!
+ * \brief The type code options DLDataType.
+ */
+typedef enum {
+ /*! \brief signed integer */
+ kDLInt = 0U,
+ /*! \brief unsigned integer */
+ kDLUInt = 1U,
+ /*! \brief IEEE floating point */
+ kDLFloat = 2U,
+ /*!
+ * \brief Opaque handle type, reserved for testing purposes.
+ * Frameworks need to agree on the handle data type for the exchange to be well-defined.
+ */
+ kDLOpaqueHandle = 3U,
+ /*! \brief bfloat16 */
+ kDLBfloat = 4U,
+ /*!
+ * \brief complex number
+ * (C/C++/Python layout: compact struct per complex number)
+ */
+ kDLComplex = 5U,
+} DLDataTypeCode;
+
+/*!
+ * \brief The data type the tensor can hold.
+ *
+ * Examples
+ * - float: type_code = 2, bits = 32, lanes=1
+ * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
+ * - int8: type_code = 0, bits = 8, lanes=1
+ * - std::complex<float>: type_code = 5, bits = 64, lanes = 1
+ */
+typedef struct {
+ /*!
+ * \brief Type code of base types.
+ * We keep it uint8_t instead of DLDataTypeCode for minimal memory
+ * footprint, but the value should be one of DLDataTypeCode enum values.
+ * */
+ uint8_t code;
+ /*!
+ * \brief Number of bits, common choices are 8, 16, 32.
+ */
+ uint8_t bits;
+ /*! \brief Number of lanes in the type, used for vector types. */
+ uint16_t lanes;
+} DLDataType;
+
+/*!
+ * \brief Plain C Tensor object, does not manage memory.
+ */
+typedef struct {
+ /*!
+ * \brief The opaque data pointer points to the allocated data. This will be
+ * CUDA device pointer or cl_mem handle in OpenCL. This pointer is always
+ * aligned to 256 bytes as in CUDA.
+ *
+ * For given DLTensor, the size of memory required to store the contents of
+ * data is calculated as follows:
+ *
+ * \code{.c}
+ * static inline size_t GetDataSize(const DLTensor* t) {
+ * size_t size = 1;
+ * for (tvm_index_t i = 0; i < t->ndim; ++i) {
+ * size *= t->shape[i];
+ * }
+ * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
+ * return size;
+ * }
+ * \endcode
+ */
+ void* data;
+ /*! \brief The device of the tensor */
+ DLDevice device;
+ /*! \brief Number of dimensions */
+ int ndim;
+ /*! \brief The data type of the pointer*/
+ DLDataType dtype;
+ /*! \brief The shape of the tensor */
+ int64_t* shape;
+ /*!
+ * \brief strides of the tensor (in number of elements, not bytes)
+ * can be NULL, indicating tensor is compact and row-majored.
+ */
+ int64_t* strides;
+ /*! \brief The offset in bytes to the beginning pointer to data */
+ uint64_t byte_offset;
+} DLTensor;
+
+/*!
+ * \brief C Tensor object, manage memory of DLTensor. This data structure is
+ * intended to facilitate the borrowing of DLTensor by another framework. It is
+ * not meant to transfer the tensor. When the borrowing framework doesn't need
+ * the tensor, it should call the deleter to notify the host that the resource
+ * is no longer needed.
+ */
+typedef struct DLManagedTensor {
+ /*! \brief DLTensor which is being memory managed */
+ DLTensor dl_tensor;
+ /*! \brief the context of the original host framework of DLManagedTensor in
+ * which DLManagedTensor is used in the framework. It can also be NULL.
+ */
+ void * manager_ctx;
+ /*! \brief Destructor signature void (*)(void*) - this should be called
+ * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
+ * if there is no way for the caller to provide a reasonable destructor.
+ * The destructors deletes the argument self as well.
+ */
+ void (*deleter)(struct DLManagedTensor * self);
+} DLManagedTensor;
+#ifdef __cplusplus
+} // DLPACK_EXTERN_C
+#endif
+#endif // DLPACK_DLPACK_H_
diff --git a/numpy/core/src/common/npy_dlpack.h b/numpy/core/src/common/npy_dlpack.h
new file mode 100644
index 000000000..14ca352c0
--- /dev/null
+++ b/numpy/core/src/common/npy_dlpack.h
@@ -0,0 +1,28 @@
+#include "Python.h"
+#include "dlpack/dlpack.h"
+
+#ifndef NPY_DLPACK_H
+#define NPY_DLPACK_H
+
+// Part of the Array API specification.
+#define NPY_DLPACK_CAPSULE_NAME "dltensor"
+#define NPY_DLPACK_USED_CAPSULE_NAME "used_dltensor"
+
+// Used internally by NumPy to store a base object
+// as it has to release a reference to the original
+// capsule.
+#define NPY_DLPACK_INTERNAL_CAPSULE_NAME "numpy_dltensor"
+
+PyObject *
+array_dlpack(PyArrayObject *self, PyObject *const *args, Py_ssize_t len_args,
+ PyObject *kwnames);
+
+
+PyObject *
+array_dlpack_device(PyArrayObject *self, PyObject *NPY_UNUSED(args));
+
+
+NPY_NO_EXPORT PyObject *
+_from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj);
+
+#endif
diff --git a/numpy/core/src/multiarray/dlpack.c b/numpy/core/src/multiarray/dlpack.c
new file mode 100644
index 000000000..291e60a22
--- /dev/null
+++ b/numpy/core/src/multiarray/dlpack.c
@@ -0,0 +1,408 @@
+#define NPY_NO_DEPRECATED_API NPY_API_VERSION
+#define _MULTIARRAYMODULE
+
+#define PY_SSIZE_T_CLEAN
+#include <Python.h>
+#include <dlpack/dlpack.h>
+
+#include "numpy/arrayobject.h"
+#include "common/npy_argparse.h"
+
+#include "common/dlpack/dlpack.h"
+#include "common/npy_dlpack.h"
+
+static void
+array_dlpack_deleter(DLManagedTensor *self)
+{
+ PyArrayObject *array = (PyArrayObject *)self->manager_ctx;
+ // This will also free the strides as it's one allocation.
+ PyMem_Free(self->dl_tensor.shape);
+ PyMem_Free(self);
+ Py_XDECREF(array);
+}
+
+/* This is exactly as mandated by dlpack */
+static void dlpack_capsule_deleter(PyObject *self) {
+ if (PyCapsule_IsValid(self, NPY_DLPACK_USED_CAPSULE_NAME)) {
+ return;
+ }
+
+ /* an exception may be in-flight, we must save it in case we create another one */
+ PyObject *type, *value, *traceback;
+ PyErr_Fetch(&type, &value, &traceback);
+
+ DLManagedTensor *managed =
+ (DLManagedTensor *)PyCapsule_GetPointer(self, NPY_DLPACK_CAPSULE_NAME);
+ if (managed == NULL) {
+ PyErr_WriteUnraisable(self);
+ goto done;
+ }
+ /*
+ * the spec says the deleter can be NULL if there is no way for the caller
+ * to provide a reasonable destructor.
+ */
+ if (managed->deleter) {
+ managed->deleter(managed);
+ /* TODO: is the deleter allowed to set a python exception? */
+ assert(!PyErr_Occurred());
+ }
+
+done:
+ PyErr_Restore(type, value, traceback);
+}
+
+/* used internally, almost identical to dlpack_capsule_deleter() */
+static void array_dlpack_internal_capsule_deleter(PyObject *self)
+{
+ /* an exception may be in-flight, we must save it in case we create another one */
+ PyObject *type, *value, *traceback;
+ PyErr_Fetch(&type, &value, &traceback);
+
+ DLManagedTensor *managed =
+ (DLManagedTensor *)PyCapsule_GetPointer(self, NPY_DLPACK_INTERNAL_CAPSULE_NAME);
+ if (managed == NULL) {
+ PyErr_WriteUnraisable(self);
+ goto done;
+ }
+ /*
+ * the spec says the deleter can be NULL if there is no way for the caller
+ * to provide a reasonable destructor.
+ */
+ if (managed->deleter) {
+ managed->deleter(managed);
+ /* TODO: is the deleter allowed to set a python exception? */
+ assert(!PyErr_Occurred());
+ }
+
+done:
+ PyErr_Restore(type, value, traceback);
+}
+
+
+// This function cannot return NULL, but it can fail,
+// So call PyErr_Occurred to check if it failed after
+// calling it.
+static DLDevice
+array_get_dl_device(PyArrayObject *self) {
+ DLDevice ret;
+ ret.device_type = kDLCPU;
+ ret.device_id = 0;
+ PyObject *base = PyArray_BASE(self);
+ // The outer if is due to the fact that NumPy arrays are on the CPU
+ // by default (if not created from DLPack).
+ if (PyCapsule_IsValid(base, NPY_DLPACK_INTERNAL_CAPSULE_NAME)) {
+ DLManagedTensor *managed = PyCapsule_GetPointer(
+ base, NPY_DLPACK_INTERNAL_CAPSULE_NAME);
+ if (managed == NULL) {
+ return ret;
+ }
+ return managed->dl_tensor.device;
+ }
+ return ret;
+}
+
+
+PyObject *
+array_dlpack(PyArrayObject *self,
+ PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
+{
+ PyObject *stream = Py_None;
+ NPY_PREPARE_ARGPARSER;
+ if (npy_parse_arguments("__dlpack__", args, len_args, kwnames,
+ "$stream", NULL, &stream, NULL, NULL, NULL)) {
+ return NULL;
+ }
+
+ if (stream != Py_None) {
+ PyErr_SetString(PyExc_RuntimeError, "NumPy only supports "
+ "stream=None.");
+ return NULL;
+ }
+
+ if ( !(PyArray_FLAGS(self) & NPY_ARRAY_WRITEABLE)) {
+ PyErr_SetString(PyExc_TypeError, "NumPy currently only supports "
+ "dlpack for writeable arrays");
+ return NULL;
+ }
+
+ npy_intp itemsize = PyArray_ITEMSIZE(self);
+ int ndim = PyArray_NDIM(self);
+ npy_intp *strides = PyArray_STRIDES(self);
+ npy_intp *shape = PyArray_SHAPE(self);
+
+ if (!PyArray_IS_C_CONTIGUOUS(self) && PyArray_SIZE(self) != 1) {
+ for (int i = 0; i < ndim; ++i) {
+ if (strides[i] % itemsize != 0) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "DLPack only supports strides which are a multiple of "
+ "itemsize.");
+ return NULL;
+ }
+ }
+ }
+
+ DLDataType managed_dtype;
+ PyArray_Descr *dtype = PyArray_DESCR(self);
+
+ if (PyDataType_ISBYTESWAPPED(dtype)) {
+ PyErr_SetString(PyExc_TypeError, "DLPack only supports native "
+ "byte swapping.");
+ return NULL;
+ }
+
+ managed_dtype.bits = 8 * itemsize;
+ managed_dtype.lanes = 1;
+
+ if (PyDataType_ISSIGNED(dtype)) {
+ managed_dtype.code = kDLInt;
+ }
+ else if (PyDataType_ISUNSIGNED(dtype)) {
+ managed_dtype.code = kDLUInt;
+ }
+ else if (PyDataType_ISFLOAT(dtype)) {
+ // We can't be sure that the dtype is
+ // IEEE or padded.
+ if (itemsize > 8) {
+ PyErr_SetString(PyExc_TypeError, "DLPack only supports IEEE "
+ "floating point types without padding.");
+ return NULL;
+ }
+ managed_dtype.code = kDLFloat;
+ }
+ else if (PyDataType_ISCOMPLEX(dtype)) {
+ // We can't be sure that the dtype is
+ // IEEE or padded.
+ if (itemsize > 16) {
+ PyErr_SetString(PyExc_TypeError, "DLPack only supports IEEE "
+ "complex point types without padding.");
+ return NULL;
+ }
+ managed_dtype.code = kDLComplex;
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError,
+ "DLPack only supports signed/unsigned integers, float "
+ "and complex dtypes.");
+ return NULL;
+ }
+
+ DLDevice device = array_get_dl_device(self);
+ if (PyErr_Occurred()) {
+ return NULL;
+ }
+
+ DLManagedTensor *managed = PyMem_Malloc(sizeof(DLManagedTensor));
+ if (managed == NULL) {
+ PyErr_NoMemory();
+ return NULL;
+ }
+
+ /*
+ * Note: the `dlpack.h` header suggests/standardizes that `data` must be
+ * 256-byte aligned. We ignore this intentionally, because `__dlpack__`
+ * standardizes that `byte_offset` must be 0 (for now) to not break pytorch:
+ * https://github.com/data-apis/array-api/issues/293#issuecomment-964111413
+ *
+ * We further assume that exporting fully unaligned data is OK even without
+ * `byte_offset` since the standard does not reject it.
+ * Presumably, pytorch will support importing `byte_offset != 0` and NumPy
+ * can choose to use it starting about 2023. At that point, it may be
+ * that NumPy MUST use `byte_offset` to adhere to the standard (as
+ * specified in the header)!
+ */
+ managed->dl_tensor.data = PyArray_DATA(self);
+ managed->dl_tensor.byte_offset = 0;
+ managed->dl_tensor.device = device;
+ managed->dl_tensor.dtype = managed_dtype;
+
+ int64_t *managed_shape_strides = PyMem_Malloc(sizeof(int64_t) * ndim * 2);
+ if (managed_shape_strides == NULL) {
+ PyErr_NoMemory();
+ PyMem_Free(managed);
+ return NULL;
+ }
+
+ int64_t *managed_shape = managed_shape_strides;
+ int64_t *managed_strides = managed_shape_strides + ndim;
+ for (int i = 0; i < ndim; ++i) {
+ managed_shape[i] = shape[i];
+ // Strides in DLPack are items; in NumPy are bytes.
+ managed_strides[i] = strides[i] / itemsize;
+ }
+
+ managed->dl_tensor.ndim = ndim;
+ managed->dl_tensor.shape = managed_shape;
+ managed->dl_tensor.strides = NULL;
+ if (PyArray_SIZE(self) != 1 && !PyArray_IS_C_CONTIGUOUS(self)) {
+ managed->dl_tensor.strides = managed_strides;
+ }
+ managed->dl_tensor.byte_offset = 0;
+ managed->manager_ctx = self;
+ managed->deleter = array_dlpack_deleter;
+
+ PyObject *capsule = PyCapsule_New(managed, NPY_DLPACK_CAPSULE_NAME,
+ dlpack_capsule_deleter);
+ if (capsule == NULL) {
+ PyMem_Free(managed);
+ PyMem_Free(managed_shape_strides);
+ return NULL;
+ }
+
+ // the capsule holds a reference
+ Py_INCREF(self);
+ return capsule;
+}
+
+PyObject *
+array_dlpack_device(PyArrayObject *self, PyObject *NPY_UNUSED(args))
+{
+ DLDevice device = array_get_dl_device(self);
+ if (PyErr_Occurred()) {
+ return NULL;
+ }
+ return Py_BuildValue("ii", device.device_type, device.device_id);
+}
+
+NPY_NO_EXPORT PyObject *
+_from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj) {
+ PyObject *capsule = PyObject_CallMethod((PyObject *)obj->ob_type,
+ "__dlpack__", "O", obj);
+ if (capsule == NULL) {
+ return NULL;
+ }
+
+ DLManagedTensor *managed =
+ (DLManagedTensor *)PyCapsule_GetPointer(capsule,
+ NPY_DLPACK_CAPSULE_NAME);
+
+ if (managed == NULL) {
+ Py_DECREF(capsule);
+ return NULL;
+ }
+
+ const int ndim = managed->dl_tensor.ndim;
+ if (ndim > NPY_MAXDIMS) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "maxdims of DLPack tensor is higher than the supported "
+ "maxdims.");
+ Py_DECREF(capsule);
+ return NULL;
+ }
+
+ DLDeviceType device_type = managed->dl_tensor.device.device_type;
+ if (device_type != kDLCPU &&
+ device_type != kDLCUDAHost &&
+ device_type != kDLROCMHost &&
+ device_type != kDLCUDAManaged) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Unsupported device in DLTensor.");
+ Py_DECREF(capsule);
+ return NULL;
+ }
+
+ if (managed->dl_tensor.dtype.lanes != 1) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Unsupported lanes in DLTensor dtype.");
+ Py_DECREF(capsule);
+ return NULL;
+ }
+
+ int typenum = -1;
+ const uint8_t bits = managed->dl_tensor.dtype.bits;
+ const npy_intp itemsize = bits / 8;
+ switch (managed->dl_tensor.dtype.code) {
+ case kDLInt:
+ switch (bits)
+ {
+ case 8: typenum = NPY_INT8; break;
+ case 16: typenum = NPY_INT16; break;
+ case 32: typenum = NPY_INT32; break;
+ case 64: typenum = NPY_INT64; break;
+ }
+ break;
+ case kDLUInt:
+ switch (bits)
+ {
+ case 8: typenum = NPY_UINT8; break;
+ case 16: typenum = NPY_UINT16; break;
+ case 32: typenum = NPY_UINT32; break;
+ case 64: typenum = NPY_UINT64; break;
+ }
+ break;
+ case kDLFloat:
+ switch (bits)
+ {
+ case 16: typenum = NPY_FLOAT16; break;
+ case 32: typenum = NPY_FLOAT32; break;
+ case 64: typenum = NPY_FLOAT64; break;
+ }
+ break;
+ case kDLComplex:
+ switch (bits)
+ {
+ case 64: typenum = NPY_COMPLEX64; break;
+ case 128: typenum = NPY_COMPLEX128; break;
+ }
+ break;
+ }
+
+ if (typenum == -1) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Unsupported dtype in DLTensor.");
+ Py_DECREF(capsule);
+ return NULL;
+ }
+
+ npy_intp shape[NPY_MAXDIMS];
+ npy_intp strides[NPY_MAXDIMS];
+
+ for (int i = 0; i < ndim; ++i) {
+ shape[i] = managed->dl_tensor.shape[i];
+ // DLPack has elements as stride units, NumPy has bytes.
+ if (managed->dl_tensor.strides != NULL) {
+ strides[i] = managed->dl_tensor.strides[i] * itemsize;
+ }
+ }
+
+ char *data = (char *)managed->dl_tensor.data +
+ managed->dl_tensor.byte_offset;
+
+ PyArray_Descr *descr = PyArray_DescrFromType(typenum);
+ if (descr == NULL) {
+ Py_DECREF(capsule);
+ return NULL;
+ }
+
+ PyObject *ret = PyArray_NewFromDescr(&PyArray_Type, descr, ndim, shape,
+ managed->dl_tensor.strides != NULL ? strides : NULL, data, 0, NULL);
+ if (ret == NULL) {
+ Py_DECREF(capsule);
+ return NULL;
+ }
+
+ PyObject *new_capsule = PyCapsule_New(managed,
+ NPY_DLPACK_INTERNAL_CAPSULE_NAME,
+ array_dlpack_internal_capsule_deleter);
+ if (new_capsule == NULL) {
+ Py_DECREF(capsule);
+ Py_DECREF(ret);
+ return NULL;
+ }
+
+ if (PyArray_SetBaseObject((PyArrayObject *)ret, new_capsule) < 0) {
+ Py_DECREF(capsule);
+ Py_DECREF(ret);
+ return NULL;
+ }
+
+ if (PyCapsule_SetName(capsule, NPY_DLPACK_USED_CAPSULE_NAME) < 0) {
+ Py_DECREF(capsule);
+ Py_DECREF(ret);
+ return NULL;
+ }
+
+ Py_DECREF(capsule);
+ return ret;
+}
+
+
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 8e2cd09eb..2ca8d9288 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -26,6 +26,7 @@
#include "shape.h"
#include "strfuncs.h"
#include "array_assign.h"
+#include "npy_dlpack.h"
#include "methods.h"
#include "alloc.h"
@@ -2975,5 +2976,13 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
{"view",
(PyCFunction)array_view,
METH_FASTCALL | METH_KEYWORDS, NULL},
+ // For data interchange between libraries
+ {"__dlpack__",
+ (PyCFunction)array_dlpack,
+ METH_FASTCALL | METH_KEYWORDS, NULL},
+
+ {"__dlpack_device__",
+ (PyCFunction)array_dlpack_device,
+ METH_NOARGS, NULL},
{NULL, NULL, 0, NULL} /* sentinel */
};
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index fcf8d945f..84179d5f0 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -70,6 +70,8 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
#include "get_attr_string.h"
#include "experimental_public_dtype_api.h" /* _get_experimental_dtype_api */
+#include "npy_dlpack.h"
+
/*
*****************************************************************************
** INCLUDE GENERATED CODE **
@@ -4231,7 +4233,6 @@ _reload_guard(PyObject *NPY_UNUSED(self)) {
Py_RETURN_NONE;
}
-
static struct PyMethodDef array_module_methods[] = {
{"_get_implementing_args",
(PyCFunction)array__get_implementing_args,
@@ -4445,6 +4446,8 @@ static struct PyMethodDef array_module_methods[] = {
{"_reload_guard", (PyCFunction)_reload_guard,
METH_NOARGS,
"Give a warning on reload and big warning in sub-interpreters."},
+ {"_from_dlpack", (PyCFunction)_from_dlpack,
+ METH_O, NULL},
{NULL, NULL, 0, NULL} /* sentinel */
};
diff --git a/numpy/core/tests/test_dlpack.py b/numpy/core/tests/test_dlpack.py
new file mode 100644
index 000000000..f848b2008
--- /dev/null
+++ b/numpy/core/tests/test_dlpack.py
@@ -0,0 +1,109 @@
+import sys
+import pytest
+
+import numpy as np
+from numpy.testing import assert_array_equal, IS_PYPY
+
+
+class TestDLPack:
+ @pytest.mark.skipif(IS_PYPY, reason="PyPy can't get refcounts.")
+ def test_dunder_dlpack_refcount(self):
+ x = np.arange(5)
+ y = x.__dlpack__()
+ assert sys.getrefcount(x) == 3
+ del y
+ assert sys.getrefcount(x) == 2
+
+ def test_dunder_dlpack_stream(self):
+ x = np.arange(5)
+ x.__dlpack__(stream=None)
+
+ with pytest.raises(RuntimeError):
+ x.__dlpack__(stream=1)
+
+ def test_strides_not_multiple_of_itemsize(self):
+ dt = np.dtype([('int', np.int32), ('char', np.int8)])
+ y = np.zeros((5,), dtype=dt)
+ z = y['int']
+
+ with pytest.raises(RuntimeError):
+ np._from_dlpack(z)
+
+ @pytest.mark.skipif(IS_PYPY, reason="PyPy can't get refcounts.")
+ def test_from_dlpack_refcount(self):
+ x = np.arange(5)
+ y = np._from_dlpack(x)
+ assert sys.getrefcount(x) == 3
+ del y
+ assert sys.getrefcount(x) == 2
+
+ @pytest.mark.parametrize("dtype", [
+ np.int8, np.int16, np.int32, np.int64,
+ np.uint8, np.uint16, np.uint32, np.uint64,
+ np.float16, np.float32, np.float64,
+ np.complex64, np.complex128
+ ])
+ def test_dtype_passthrough(self, dtype):
+ x = np.arange(5, dtype=dtype)
+ y = np._from_dlpack(x)
+
+ assert y.dtype == x.dtype
+ assert_array_equal(x, y)
+
+ def test_invalid_dtype(self):
+ x = np.asarray(np.datetime64('2021-05-27'))
+
+ with pytest.raises(TypeError):
+ np._from_dlpack(x)
+
+ def test_invalid_byte_swapping(self):
+ dt = np.dtype('=i8').newbyteorder()
+ x = np.arange(5, dtype=dt)
+
+ with pytest.raises(TypeError):
+ np._from_dlpack(x)
+
+ def test_non_contiguous(self):
+ x = np.arange(25).reshape((5, 5))
+
+ y1 = x[0]
+ assert_array_equal(y1, np._from_dlpack(y1))
+
+ y2 = x[:, 0]
+ assert_array_equal(y2, np._from_dlpack(y2))
+
+ y3 = x[1, :]
+ assert_array_equal(y3, np._from_dlpack(y3))
+
+ y4 = x[1]
+ assert_array_equal(y4, np._from_dlpack(y4))
+
+ y5 = np.diagonal(x).copy()
+ assert_array_equal(y5, np._from_dlpack(y5))
+
+ @pytest.mark.parametrize("ndim", range(33))
+ def test_higher_dims(self, ndim):
+ shape = (1,) * ndim
+ x = np.zeros(shape, dtype=np.float64)
+
+ assert shape == np._from_dlpack(x).shape
+
+ def test_dlpack_device(self):
+ x = np.arange(5)
+ assert x.__dlpack_device__() == (1, 0)
+ assert np._from_dlpack(x).__dlpack_device__() == (1, 0)
+
+ def dlpack_deleter_exception(self):
+ x = np.arange(5)
+ _ = x.__dlpack__()
+ raise RuntimeError
+
+ def test_dlpack_destructor_exception(self):
+ with pytest.raises(RuntimeError):
+ self.dlpack_deleter_exception()
+
+ def test_readonly(self):
+ x = np.arange(5)
+ x.flags.writeable = False
+ with pytest.raises(TypeError):
+ x.__dlpack__()