summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/_creation_functions.py6
-rw-r--r--numpy/array_api/tests/test_creation_functions.py16
2 files changed, 14 insertions, 8 deletions
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py
index 23beec444..741498ff6 100644
--- a/numpy/array_api/_creation_functions.py
+++ b/numpy/array_api/_creation_functions.py
@@ -41,7 +41,7 @@ def asarray(
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
- copy: Optional[bool] = None,
+ copy: Optional[Union[bool, np._CopyMode]] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
@@ -55,13 +55,13 @@ def asarray(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
- if copy is False:
+ if copy in (False, np._CopyMode.IF_NEEDED):
# Note: copy=False is not yet implemented in np.asarray
raise NotImplementedError("copy=False is not yet implemented")
if isinstance(obj, Array):
if dtype is not None and obj.dtype != dtype:
copy = True
- if copy is True:
+ if copy in (True, np._CopyMode.ALWAYS):
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
return obj
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py
index ebbb6aab3..be9eaa383 100644
--- a/numpy/array_api/tests/test_creation_functions.py
+++ b/numpy/array_api/tests/test_creation_functions.py
@@ -43,12 +43,18 @@ def test_asarray_copy():
a[0] = 0
assert all(b[0] == 1)
assert all(a[0] == 0)
- # Once copy=False is implemented, replace this with
- # a = asarray([1])
- # b = asarray(a, copy=False)
- # a[0] = 0
- # assert all(b[0] == 0)
+ a = asarray([1])
+ b = asarray(a, copy=np._CopyMode.ALWAYS)
+ a[0] = 0
+ assert all(b[0] == 1)
+ assert all(a[0] == 0)
+ a = asarray([1])
+ b = asarray(a, copy=np._CopyMode.NEVER)
+ a[0] = 0
+ assert all(b[0] == 0)
assert_raises(NotImplementedError, lambda: asarray(a, copy=False))
+ assert_raises(NotImplementedError,
+ lambda: asarray(a, copy=np._CopyMode.IF_NEEDED))
def test_arange_errors():