summaryrefslogtreecommitdiff
path: root/numpy/array_api/_creation_functions.py
diff options
context:
space:
mode:
authorRalf Gommers <ralf.gommers@gmail.com>2021-11-12 16:15:35 +0100
committerRalf Gommers <ralf.gommers@gmail.com>2021-11-12 16:15:35 +0100
commiteccb8dfbd9b07183e16a1144e8d5d76936671bfc (patch)
tree647a9477b4f3b8b7205f2f7f2feb99eaa482e806 /numpy/array_api/_creation_functions.py
parentd0d75f39f28ac26d4cc1aa3a4cbea63a6a027929 (diff)
parentff2e2a1e7eea29d925063b13922e096d14331222 (diff)
downloadnumpy-eccb8dfbd9b07183e16a1144e8d5d76936671bfc.tar.gz
Merge branch 'main' into never_copy
Diffstat (limited to 'numpy/array_api/_creation_functions.py')
-rw-r--r--numpy/array_api/_creation_functions.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py
index d760bf2fc..741498ff6 100644
--- a/numpy/array_api/_creation_functions.py
+++ b/numpy/array_api/_creation_functions.py
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
Device,
Dtype,
NestedSequence,
- SupportsDLPack,
SupportsBufferProtocol,
)
from collections.abc import Sequence
@@ -36,7 +35,6 @@ def asarray(
int,
float,
NestedSequence[bool | int | float],
- SupportsDLPack,
SupportsBufferProtocol,
],
/,
@@ -60,7 +58,9 @@ def asarray(
if copy in (False, np._CopyMode.IF_NEEDED):
# Note: copy=False is not yet implemented in np.asarray
raise NotImplementedError("copy=False is not yet implemented")
- if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype):
+ if isinstance(obj, Array):
+ if dtype is not None and obj.dtype != dtype:
+ copy = True
if copy in (True, np._CopyMode.ALWAYS):
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
return obj
@@ -152,8 +152,9 @@ def eye(
def from_dlpack(x: object, /) -> Array:
- # Note: dlpack support is not yet implemented on Array
- raise NotImplementedError("DLPack support is not yet implemented")
+ from ._array_object import Array
+
+ return Array._new(np._from_dlpack(x))
def full(
@@ -240,6 +241,12 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
"""
from ._array_object import Array
+ # Note: unlike np.meshgrid, only inputs with all the same dtype are
+ # allowed
+
+ if len({a.dtype for a in arrays}) > 1:
+ raise ValueError("meshgrid inputs must all have the same dtype")
+
return [
Array._new(array)
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)