diff options
Diffstat (limited to 'numpy/array_api/_creation_functions.py')
-rw-r--r-- | numpy/array_api/_creation_functions.py | 23 |
1 files changed, 15 insertions, 8 deletions
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index e36807468..741498ff6 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: Device, Dtype, NestedSequence, - SupportsDLPack, SupportsBufferProtocol, ) from collections.abc import Sequence @@ -36,14 +35,13 @@ def asarray( int, float, NestedSequence[bool | int | float], - SupportsDLPack, SupportsBufferProtocol, ], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - copy: Optional[bool] = None, + copy: Optional[Union[bool, np._CopyMode]] = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`. @@ -57,11 +55,13 @@ def asarray( _check_valid_dtype(dtype) if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") - if copy is False: + if copy in (False, np._CopyMode.IF_NEEDED): # Note: copy=False is not yet implemented in np.asarray raise NotImplementedError("copy=False is not yet implemented") - if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype): - if copy is True: + if isinstance(obj, Array): + if dtype is not None and obj.dtype != dtype: + copy = True + if copy in (True, np._CopyMode.ALWAYS): return Array._new(np.array(obj._array, copy=True, dtype=dtype)) return obj if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): @@ -152,8 +152,9 @@ def eye( def from_dlpack(x: object, /) -> Array: - # Note: dlpack support is not yet implemented on Array - raise NotImplementedError("DLPack support is not yet implemented") + from ._array_object import Array + + return Array._new(np._from_dlpack(x)) def full( @@ -240,6 +241,12 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: """ from ._array_object import Array + # Note: unlike np.meshgrid, only inputs with all the same dtype are + # allowed + + if len({a.dtype for a in arrays}) > 1: + raise ValueError("meshgrid inputs must all have the same dtype") + return [ Array._new(array) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing) |