diff options
author | Ralf Gommers <ralf.gommers@gmail.com> | 2021-11-12 16:15:35 +0100 |
---|---|---|
committer | Ralf Gommers <ralf.gommers@gmail.com> | 2021-11-12 16:15:35 +0100 |
commit | eccb8dfbd9b07183e16a1144e8d5d76936671bfc (patch) | |
tree | 647a9477b4f3b8b7205f2f7f2feb99eaa482e806 /numpy/array_api/_creation_functions.py | |
parent | d0d75f39f28ac26d4cc1aa3a4cbea63a6a027929 (diff) | |
parent | ff2e2a1e7eea29d925063b13922e096d14331222 (diff) | |
download | numpy-eccb8dfbd9b07183e16a1144e8d5d76936671bfc.tar.gz |
Merge branch 'main' into never_copy
Diffstat (limited to 'numpy/array_api/_creation_functions.py')
-rw-r--r-- | numpy/array_api/_creation_functions.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index d760bf2fc..741498ff6 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: Device, Dtype, NestedSequence, - SupportsDLPack, SupportsBufferProtocol, ) from collections.abc import Sequence @@ -36,7 +35,6 @@ def asarray( int, float, NestedSequence[bool | int | float], - SupportsDLPack, SupportsBufferProtocol, ], /, @@ -60,7 +58,9 @@ def asarray( if copy in (False, np._CopyMode.IF_NEEDED): # Note: copy=False is not yet implemented in np.asarray raise NotImplementedError("copy=False is not yet implemented") - if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype): + if isinstance(obj, Array): + if dtype is not None and obj.dtype != dtype: + copy = True if copy in (True, np._CopyMode.ALWAYS): return Array._new(np.array(obj._array, copy=True, dtype=dtype)) return obj @@ -152,8 +152,9 @@ def eye( def from_dlpack(x: object, /) -> Array: - # Note: dlpack support is not yet implemented on Array - raise NotImplementedError("DLPack support is not yet implemented") + from ._array_object import Array + + return Array._new(np._from_dlpack(x)) def full( @@ -240,6 +241,12 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]: """ from ._array_object import Array + # Note: unlike np.meshgrid, only inputs with all the same dtype are + # allowed + + if len({a.dtype for a in arrays}) > 1: + raise ValueError("meshgrid inputs must all have the same dtype") + return [ Array._new(array) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing) |