diff options
Diffstat (limited to 'numpy/array_api/_creation_functions.py')
-rw-r--r-- | numpy/array_api/_creation_functions.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index d760bf2fc..741498ff6 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: Device, Dtype, NestedSequence, - SupportsDLPack, SupportsBufferProtocol, ) from collections.abc import Sequence @@ -36,7 +35,6 @@ def asarray( int, float, NestedSequence[bool | int | float], - SupportsDLPack, SupportsBufferProtocol, ], /, @@ -60,7 +58,9 @@ def asarray( if copy in (False, np._CopyMode.IF_NEEDED): # Note: copy=False is not yet implemented in np.asarray raise NotImplementedError("copy=False is not yet implemented") - if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype): + if isinstance(obj, Array): + if dtype is not None and obj.dtype != dtype: + copy = True if copy in (True, np._CopyMode.ALWAYS): return Array._new(np.array(obj._array, copy=True, dtype=dtype)) return obj @@ -152,8 +152,9 @@ def eye( def from_dlpack(x: object, /) -> Array: - # Note: dlpack support is not yet implemented on Array - raise NotImplementedError("DLPack support is not yet implemented") + from ._array_object import Array + + return Array._new(np._from_dlpack(x)) def full( @@ -240,6 +241,12 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: """ from ._array_object import Array + # Note: unlike np.meshgrid, only inputs with all the same dtype are + # allowed + + if len({a.dtype for a in arrays}) > 1: + raise ValueError("meshgrid inputs must all have the same dtype") + return [ Array._new(array) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing) |